diff --git a/.bazelrc b/.bazelrc index 1dd928acdb4..1b9f5e87c6b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -18,8 +18,10 @@ # # Compiler options: # cuda_clang: Use clang when building CUDA code. -# c++17: Build with C++17 options -# c++1z: Build with C++17 options +# c++17: Build with C++17 options (links with libc++) +# c++1z: Build with C++17 options (links with libc++) +# c++17_gcc: Build with C++17 options (links with stdlibc++) +# c++1z_gcc: Build with C++17 options (links with stdlibc++) # avx_linux: Build with avx instruction set on linux. # avx2_linux: Build with avx2 instruction set on linux. # native_arch_linux: Build with instruction sets available to the host machine on linux @@ -28,6 +30,7 @@ # # Other build options: # short_logs: Only log errors during build, skip warnings. +# verbose_logs: Show all compiler warnings during build. # monolithic: Build all TF C++ code into a single shared object. # dynamic_kernels: Try to link all kernels dynamically (experimental). # libc++: Link against libc++ instead of stdlibc++ @@ -78,7 +81,16 @@ # elinux: General Embedded Linux options shared by all flavors. # elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support. # elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support. - +# +# Release build options (for all operating systems) +# release_common: Common options for all builds on all operating systems. +# release_windows_common: Common options for all builds on Windows. +# release_gpu_common: Common options for GPU builds on Linux and Windows. +# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds. +# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds. +# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds. +# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds. +# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds. # Allow builds using libc++ as a linker library # This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file @@ -155,14 +167,29 @@ build:mkl -c opt # config to build OneDNN backend with a user specified threadpool. build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true +build:mkl_threadpool --define=build_with_mkl_opensource=true build:mkl_threadpool --define=build_with_mkldnn_threadpool=true build:mkl_threadpool -c opt + +# Config setting to build with oneDNN and without the binary blob +build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true +build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0 +build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true +build:mkl_opensource_only --define=build_with_mkl_opensource=true +build:mkl_opensource_only -c opt + # This config refers to building with CUDA available. It does not necessarily # mean that we build CUDA op kernels. build:using_cuda --define=using_cuda=true build:using_cuda --action_env TF_NEED_CUDA=1 build:using_cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +# Enable the mlir generated GPU kernels only for cuda builds. +build --define=tensorflow_enable_mlir_generated_gpu_kernels=0 +# This is a more specific option, so it takes precedence over the line above for cuda builds. +build:using_cuda --define=tensorflow_enable_mlir_generated_gpu_kernels=1 + # This config refers to building CUDA op kernels with nvcc. 
build:cuda --config=using_cuda build:cuda --define=using_cuda_nvcc=true @@ -253,6 +280,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS build:c++17 --cxxopt=-std=c++1z build:c++17 --cxxopt=-stdlib=libc++ build:c++1z --config=c++17 +build:c++17_gcc --cxxopt=-std=c++1z +build:c++1z_gcc --config=c++17_gcc # Enable using platform specific build settings, except when cross-compiling for # mobile platforms. @@ -322,6 +351,8 @@ build:windows --distinct_host_configuration=false # Suppress all warning messages. build:short_logs --output_filter=DONT_MATCH_ANYTHING +build:verbose_logs --output_filter= +build --config=short_logs # Instruction set optimizations # TODO(gunan): Create a feature in toolchains for avx/avx2 to @@ -341,7 +372,6 @@ build --config=v2 test --config=v2 # Enable XLA -build:xla --action_env=TF_ENABLE_XLA=1 build:xla --define=with_xla_support=true # BEGIN TF REMOTE BUILD EXECUTION OPTIONS @@ -534,3 +564,43 @@ try-import %workspace%/.tf_configure.bazelrc # Put user-specific options in .bazelrc.user try-import %workspace%/.bazelrc.user + +# Here are bazelrc configs for release builds +build:release_common --config=opt +build:release_common --config=v2 +build:release_common --distinct_host_configuration=false +build:release_common --action_env TF_CONFIGURE_IOS="0" + +build:release_cpu_linux --config=release_common +build:release_cpu_linux --config=avx_linux +# We use the same toolchain for CPU/GPU packages. +# Did not add this to the defaults in case this changes. +build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain + +build:release_cpu_macos --config=release_common +build:release_cpu_macos --config=avx_linux + +build:release_gpu_common --config=release_common +build:release_gpu_common --config=cuda +build:release_gpu_common --config=tensorrt +build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" +build:release_gpu_common --action_env=TF_CUDA_VERSION="10" +build:release_gpu_common --action_env=TF_CUDNN_VERSION="7" +build:release_gpu_common --action_env=TF_NEED_TENSORRT="1" +build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70" +build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt" +build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" + + +build:release_gpu_linux --config=release_gpu_common +build:release_gpu_linux --config=avx_linux +build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain + +build:release_windows_common --config=release_common +build:release_windows_common --define=no_tensorflow_py_deps=true +build:release_windows_common --announce_rc + +build:release_cpu_windows --config=release_windows_common + +build:release_gpu_windows --config=release_windows_common diff --git a/README.md b/README.md index 9cf595bbf61..6398e8e27a1 100644 --- a/README.md +++ b/README.md @@ -123,20 +123,21 @@ Build Type | Status ### Community Supported Builds -Build Type | Status | Artifacts ------------------------------------------------------------------------------------ | 
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- -**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) -**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/) -**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) -**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) -**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) -**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) -**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) -**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) -**Linux aarch64 CPU** Nightly
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) -**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) -**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) -**Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) +Build Type | Status | Artifacts +----------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------- +**Linux AMD ROCm GPU** Nightly | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly) | [Nightly](http://ml-ci.amd.com:21096/job/tensorflow-rocm-nightly/lastSuccessfulBuild/) +**Linux AMD ROCm GPU** Stable Release | [![Build Status](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/badge/icon)](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/) | Release [1.15](http://ml-ci.amd.com:21096/job/tensorflow-rocm-release/lastSuccessfulBuild/) / [2.x](http://ml-ci.amd.com:21096/job/tensorflow-rocm-v2-release/lastSuccessfulBuild/) +**Linux s390x** Nightly | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/badge/icon)](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) | [Nightly](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_CI/) +**Linux s390x CPU** Stable Release | [![Build Status](http://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/badge/icon)](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) | [Release](https://ibmz-ci.osuosl.org/job/TensorFlow_IBMZ_Release_Build/) +**Linux ppc64le CPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/) +**Linux ppc64le CPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_CPU_Release_Build/) +**Linux ppc64le GPU** Nightly | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Build/) | [Nightly](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/) +**Linux ppc64le GPU** Stable Release | [![Build Status](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/badge/icon)](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) | Release [1.15](https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/) / [2.x](https://powerci.osuosl.org/job/TensorFlow2_PPC64LE_GPU_Release_Build/) +**Linux aarch64 CPU** Nightly
Python 3.6 | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow)](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) | [Nightly](https://status.openlabtesting.org/builds/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-arm64-build-daily-master) +**Linux aarch64 CPU** Stable Release | [![Build Status](http://openlabtesting.org:15000/badge?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show)](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) | Release [1.15](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v1.15.3-cpu-arm64-release-build-show) / [2.x](http://status.openlabtesting.org/builds?project=tensorflow%2Ftensorflow&job_name=tensorflow-v2.1.0-cpu-arm64-release-build-show) +**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Nightly | [![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/badge/icon)](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) | [Nightly](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-whl-nightly/) +**Linux CPU with Intel oneAPI Deep Neural Network Library (oneDNN)** Stable Release | ![Build Status](https://tensorflow-ci.intel.com/job/tensorflow-mkl-build-release-whl/badge/icon) | Release [1.15](https://pypi.org/project/intel-tensorflow/1.15.0/) / [2.x](https://pypi.org/project/intel-tensorflow/) +**Red Hat® Enterprise Linux® 7.6 CPU & GPU**
Python 2.7, 3.6 | [![Build Status](https://jenkins-tensorflow.apps.ci.centos.org/buildStatus/icon?job=tensorflow-rhel7-3.6&build=2)](https://jenkins-tensorflow.apps.ci.centos.org/job/tensorflow-rhel7-3.6/2/) | [1.13.1 PyPI](https://tensorflow.pypi.thoth-station.ninja/index/) ## Resources diff --git a/RELEASE.md b/RELEASE.md index 69eca82c5f2..430e1b83885 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,10 +11,28 @@ * C-API functions `TF_StringDecode`, `TF_StringEncode`, and `TF_StringEncodedSize` are no longer relevant and have been removed; see core/platform/ctstring.h for string access/modification in C. -* In batching library, rename parameter - SharedBatchScheduler::QueueOptions::max_batch_size to a more accurate name - (input_batch_size_limit) for a recent feature to enable split of large batch - sizes. +* Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2. +* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are + now hidden. These modules are not part of TensorFlow public API. +* A major refactoring of the internals of the Keras Functional API may affect code that is relying on certain internal details: + * Code that uses `isinstance(x, tf.Tensor)` instead of `tf.is_tensor` when checking Keras symbolic inputs/outputs should switch to using `tf.is_tensor`. + * Code that is overly dependent on the exact names attached to symbolic tensors (e.g. assumes there will be ":0" at the end of the inputs, treats names as unique identifiers instead of using `tensor.ref()`, etc.) + * Code that uses `get_concrete_function` to trace Keras symbolic inputs directly should switch to building matching `tf.TensorSpec`s directly and tracing the `TensorSpec` objects. + * Code that relies on the exact number and names of the op layers that TensorFlow operations were converted into. These may have changed. + * Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy. + * Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values. + * Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now. + * Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient. + * Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know. + * Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed. +* Start enforcing input shape assumptions when calling Functional API Keras + models. This may potentially break some users, in case there is a mismatch + between the shape used when creating `Input` objects in a Functional model, + and the shape of the data passed to that model. You can fix this mismatch by + either calling the model with correctly-shaped data, or by relaxing `Input` + shape assumptions (note that you can pass shapes with `None` entries for axes + that are meant to be dynamic). 
You can also disable the input checking + entirely by setting `model.input_spec = None`. ## Known Caveats @@ -24,6 +42,8 @@ * * +* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy, and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) are provided. Their inter-operation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what is supported and how it differs from NumPy. +* A major refactoring of the internals of the Keras Functional API has been completed; it should improve the reliability, stability, and performance of constructing Functional models. ## Bug Fixes and Other Changes * * * TF Core: - * - * `tf.Tensor` is now a subclass of `typing.Generic`, allowing type annotations - to be parameterized by dtype: `tf.Tensor[tf.Int32]`. This requires Python 3, - and will become fully compatible with static type checkers in the future. - + * `tf.types.experimental.TensorLike` is a new `Union` type that can be used as + type annotation for variables representing a Tensor or a value that can be + converted to Tensor by `tf.convert_to_tensor`. + * Calling ops with Python constants or NumPy values is now consistent with + tf.convert_to_tensor behavior. This avoids operations like tf.reshape + truncating inputs such as from int64 to int32. + * Added `tf.sparse.map_values` to apply a function to the `.values` of `SparseTensor` arguments. + * The Python bitwise operators for `Tensor` (`__and__`, `__or__`, `__xor__` + and `__invert__`) now support non-`bool` arguments and apply the + corresponding bitwise ops. `bool` arguments continue to be supported and + dispatch to logical ops. This brings them more in line with Python and NumPy + behavior. + * Added `tf.SparseTensor.with_values`. This returns a new SparseTensor with + the same sparsity pattern, but with new provided values. It is similar to + the `with_values` function of `RaggedTensor`. + * Added `StatelessCase` op, and uses it if none of the case branches has stateful ops. * `tf.data`: + * Added new `tf.data.experimental.service.register_dataset` and + `tf.data.experimental.service.from_dataset_id` APIs to enable one process + to register a dataset with the tf.data service, and another process to + consume data from the dataset. + * Added support for tf.data service dispatcher fault tolerance. To enable + fault tolerance, configure a `work_dir` when running your dispatcher + server and set `dispatcher_fault_tolerance=True`. The dispatcher will + store its state to `work_dir`, so that on restart it can continue from its + previous state. * Added optional `exclude_cols` parameter to CsvDataset. This parameter is - the complement of `select_cols`; at most one of these should be specified. + the complement of `select_cols`; at most one of these should be specified. + * We have implemented an optimization which reorders data-discarding + transformations such as `take` and `shard` to happen earlier in the + dataset when it is safe to do so. The optimization can be disabled via + the `experimental_optimization.reorder_data_discarding_ops` dataset + option. +* `tf.image`: + * Added deterministic `tf.image.stateless_random_*` functions for each + `tf.image.random_*` function.
Added a new op + `stateless_sample_distorted_bounding_box` which is a deterministic + version of the `sample_distorted_bounding_box` op. Given the same seed, these + stateless functions/ops produce the same results independent of how many + times the function is called, and independent of global seed settings. * `tf.distribute`: * -* `tf.keras`: - * -* `tf.function`/AutoGraph: - * +* `tf.keras`: + * Improvements from the functional API refactoring: + * Functional model construction does not need to maintain a global workspace graph, removing memory leaks especially when building many models or very large models. + * Functional model construction should be ~8-10% faster on average. + * Functional models can now contain non-symbolic values in their call inputs inside of the first positional argument. + * Several classes of TF ops that were not reliably converted to Keras layers during functional API construction should now work, e.g. `tf.image.ssim_multiscale`. + * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand. + * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` + as an alternative to accepting a `callable` loss. + * Added the `beta` parameter to the FTRL optimizer to match the paper. + * Added `mobilenet_v3` to the Keras application models. +* `tf.function` / AutoGraph: + * Added `experimental_follow_type_hints` argument for `tf.function`. When + True, the function may use type annotations to optimize its tracing + performance. + * Added support for `iter(DistributedDataset)` in AutoGraph `for` loops. + * AutoGraph now allows creating new symbols inside a TensorFlow loop, if + the values of these symbols at an iteration do not depend on the previous + iteration. These types of loops must run at least one iteration, and will + raise a runtime error otherwise. + + Example: + + ``` + for batch in data: + outputs = train_step(batch) + tf.print('final outputs', outputs) + ``` + See tensorflow/python/autograph/g3doc/reference/limitations.md for more + info. * `tf.lite`: + * `DynamicBuffer::AddJoinedString()` will now add a separator if the first + string to be joined is empty. + * `TFLiteConverter`: + * Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`). + * Deprecate `Interpreter::UseNNAPI(bool)` C++ API. + * Prefer using `NnApiDelegate()` and related delegate configuration methods directly. + * Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair. * * `tf.random`: * * Math and Linear Algebra: * * TPU Enhancements: + * Added support for the `beta` parameter of the FTRL optimizer for TPU + embeddings. Users of other TensorFlow platforms can implement equivalent + behavior by adjusting the `l2` parameter. * * XLA Support: + * `xla.experimental.compile` is deprecated; use + `tf.function(experimental_compile=True)` instead. + * * Tracing and Debugging: * * Other: - * We have replaced uses of "whitelist" with "allowlist" where possible. - Please see https://developers.google.com/style/word-list#blacklist for more - context. + * We have replaced uses of "whitelist" and "blacklist" with "allowlist" + and "denylist" where possible. Please see + https://developers.google.com/style/word-list#blacklist for more context.
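For illustration, a minimal sketch of the determinism contract described in the `tf.image` item above, assuming the TF 2.4 nightly API (the exact set of exported `stateless_random_*` ops is whatever the nightly provides):

```
import tensorflow as tf

image = tf.zeros([4, 64, 64, 3])   # dummy batch of images
seed = (1, 2)                      # stateless ops take an explicit [2]-element seed

# Same seed -> same result, regardless of how many times the op is called
# and regardless of any global seed settings.
a = tf.image.stateless_random_flip_left_right(image, seed=seed)
b = tf.image.stateless_random_flip_left_right(image, seed=seed)
assert tf.reduce_all(a == b)
```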
* ## Thanks to our Contributors @@ -71,19 +161,206 @@ stjohnso98, , , , , # Release 2.3.0 -## Breaking Changes +## Major Features and Improvements + * `tf.data` adds two new mechanisms to solve input pipeline bottlenecks and save resources: + * [snapshot](https://www.tensorflow.org/api_docs/python/tf/data/experimental/snapshot) + * [tf.data service](https://www.tensorflow.org/api_docs/python/tf/data/experimental/service). -* `tf.image.extract_glimpse` has been updated to correctly process the case - where `centered=False` and `normalized=False`. This is a breaking change as - the output is different from (incorrect) previous versions. Note this - breaking change only impacts `tf.image.extract_glimpse` and - `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of - `tf.compat.v1.image.extract_glimpse` does not change. The behavior of - exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved - models will not be impacted. + In addition checkout the detailed [guide](https://www.tensorflow.org/guide/data_performance_analysis) for analyzing input pipeline performance with TF Profiler. + + * [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is now a stable API and no longer considered experimental for TensorFlow. (earlier `tf.distribute.experimental.TPUStrategy`). + + * [TF Profiler](https://www.tensorflow.org/guide/profiler) introduces two new tools: a memory profiler to visualize your model’s memory usage over time and a [python tracer](https://www.tensorflow.org/guide/profiler#events) which allows you to trace python function calls in your model. Usability improvements include better diagnostic messages and [profile options](https://tensorflow.org/guide/profiler#collect_performance_data) to customize the host and device trace verbosity level. + + * Introduces experimental support for Keras Preprocessing Layers API ([`tf.keras.layers.experimental.preprocessing.*`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing?version=nightly)) to handle data preprocessing operations, with support for composite tensor inputs. Please see below for additional details on these layers. + + * TFLite now properly supports dynamic shapes during conversion and inference. We’ve also added opt-in support on Android and iOS for [XNNPACK](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/delegates/xnnpack), a highly optimized set of CPU kernels, as well as opt-in support for [executing quantized models on the GPU](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/performance/gpu_advanced.md#running-quantized-models-experimental). + + * Libtensorflow packages are available in GCS starting this release. We have also started to [release a nightly version of these packages](https://github.com/tensorflow/tensorflow#official-builds). + + * The experimental Python API [`tf.debugging.experimental.enable_dump_debug_info()`](https://www.tensorflow.org/api_docs/python/tf/debugging/experimental/enable_dump_debug_info) now allows you to instrument a TensorFlow program and dump debugging information to a directory on the file system. 
The directory can be read and visualized by a new interactive dashboard in TensorBoard 2.3 called [Debugger V2](https://www.tensorflow.org/tensorboard/debugger_v2), which reveals the details of the TensorFlow program including graph structures, history of op executions at the Python (eager) and intra-graph levels, the runtime dtype, shape, and numerical composition of tensors, as well as their code locations. + +## Breaking Changes +* Increases the **minimum bazel version** required to build TF to **3.1.0**. +* `tf.data` + * Makes the following (breaking) changes to `tf.data`: + * C++ API: - `IteratorBase::RestoreInternal`, `IteratorBase::SaveInternal`, and `DatasetBase::CheckExternalState` become pure-virtual and subclasses are now expected to provide an implementation. + * The deprecated `DatasetBase::IsStateful` method is removed in favor of `DatasetBase::CheckExternalState`. + * Deprecated overrides of `DatasetBase::MakeIterator` and `MakeIteratorFromInputElement` are removed. + * The signature of `tensorflow::data::IteratorBase::SaveInternal` and `tensorflow::data::IteratorBase::SaveInput` has been extended with a `SerializationContext` argument to enable overriding the default policy for handling external state during iterator checkpointing. This is not a backwards compatible change and all subclasses of `IteratorBase` *need to be updated* accordingly. +* `tf.keras` + * Add a new `BackupAndRestore` callback for handling distributed training failures & restarts. Please take a look at this [tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras) for details on how to use the callback. +* `tf.image.extract_glimpse` has been updated to correctly process the case + where `centered=False` and `normalized=False`. This is a breaking change as + the output is different from (incorrect) previous versions. Note this + breaking change only impacts `tf.image.extract_glimpse` and + `tf.compat.v2.image.extract_glimpse` API endpoints. The behavior of + `tf.compat.v1.image.extract_glimpse` does not change. The behavior of + the existing C++ kernel `ExtractGlimpse` does not change either, so saved + models using `tf.raw_ops.ExtractGlimpse` will not be impacted. + +## Known Caveats + * `tf.lite` + * Keras-based LSTM models must be converted with an explicit batch size in the input layer. ## Bug Fixes and Other Changes -* Mutable tables now restore checkpointed values when loaded from SavedModel. + +### TF Core: + * Set `tf2_behavior` to 1 to enable V2 for early loading cases. + * Add `execute_fn_for_device` function to dynamically choose the implementation based on underlying device placement. + * Eager: + * Add `reduce_logsumexp` benchmark with experimental compile. + * Give `EagerTensor`s a meaningful `__array__` implementation. + * Add another version of defun matmul for performance analysis. + * `tf.function`/AutoGraph: + * `AutoGraph` now includes into TensorFlow loops any variables that are closed over by local functions. Previously, such variables were sometimes incorrectly ignored. + * Functions returned by the `get_concrete_function` method of `tf.function` objects can now be called with arguments consistent with the original arguments or type specs passed to `get_concrete_function`. This calling convention is now the preferred way to use concrete functions with nested values and composite tensors. Please check the [guide](https://www.tensorflow.org/guide/concrete_function) for more details on `concrete_function`.
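As a small illustration of that calling convention (a sketch assuming TF 2.3; the structured call is the point, not the particular function):

```
import tensorflow as tf

@tf.function
def combine(d):
    return d["a"] * d["b"]

# Trace once for a structured (nested) input signature...
cf = combine.get_concrete_function(
    {"a": tf.TensorSpec([None], tf.float32), "b": tf.TensorSpec([], tf.float32)})

# ...then call the concrete function with the same nested structure,
# rather than with manually flattened tensors.
out = cf({"a": tf.constant([1.0, 2.0]), "b": tf.constant(3.0)})
```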
+ * Update `tf.function`'s `experimental_relax_shapes` to handle composite tensors appropriately. + * Optimize `tf.function` invocation, by removing redundant list converter. + * `tf.function` will retrace when called with a different variable instead of simply using the `dtype` & `shape`. + * [Improve support](https://github.com/tensorflow/tensorflow/issues/33862) for dynamically-sized TensorArray inside `tf.function`. + * `tf.math`: + * Narrow down `argmin`/`argmax` contract to always return the smallest index for ties. + * `tf.math.reduce_variance` and `tf.math.reduce_std` return correct computation for complex types and no longer support integer types. + * Add Bessel functions of order 0,1 to `tf.math.special`. + * `tf.divide` now always returns a tensor to be consistent with documentation and other APIs. + * `tf.image`: + * Replaced [`tf.image.non_max_suppression_padded`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/image/non_max_suppression_padded?hl=en) with a new implementation that supports batched inputs, which is considerably faster on TPUs and GPUs. Boxes with area=0 will be ignored. Existing usage with single inputs should still work as before. + * `tf.linalg` + * Add `tf.linalg.banded_triangular_solve`. + * `tf.random`: + * Add `tf.random.stateless_parameterized_truncated_normal`. + * `tf.ragged`: + * Add `tf.ragged.cross` and `tf.ragged.cross_hashed` operations. + * `tf.RaggedTensor`: + * `RaggedTensor.to_tensor()` now preserves static shape. + * Add `tf.strings.format()` and `tf.print()` to support RaggedTensors. + * `tf.saved_model`: + * `@tf.function` from SavedModel no longer ignores args after a `RaggedTensor` when selecting the concrete function to run. + * Fix save model issue for ops with a list of functions. + * Add `tf.saved_model.LoadOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/LoadOptions?hl=en) as arg with default value `None` to choose the I/O device for loading models and weights. + * Update `tf.saved_model.SaveOptions` with [`experimental_io_device`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/saved_model/SaveOptions?hl=en) as arg with default value `None` to choose the I/O device for saving models and weights. + * Mutable tables now restore checkpointed values when loaded from SavedModel. + * GPU + * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities. + * Others + * Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`. + * Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations. + * `tf.custom_gradient` can now be applied to functions that accept nested structures of `tensors` as inputs (instead of just a list of tensors). Note that Python structures such as tuples and lists now won't be treated as tensors, so if you still want them to be treated that way, you need to wrap them with `tf.convert_to_tensor`. + * No lowering on gradient case op when input is `DeviceIndex` op. + * Extend the ragged version of `tf.gather` to support `batch_dims` and `axis` args. + * Update `tf.map_fn` to support RaggedTensors and SparseTensors. + * Deprecate `tf.group`. It is not useful in eager mode. 
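A brief sketch of how the `experimental_io_device` save/load options listed above are intended to be used; the path, model, and device string are illustrative placeholders, not part of this change:

```
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Route SavedModel file I/O through a specific device, e.g. the local host
# when saving from a remote worker.
save_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")
tf.saved_model.save(model, "/tmp/my_model", options=save_options)

load_options = tf.saved_model.LoadOptions(experimental_io_device="/job:localhost")
restored = tf.saved_model.load("/tmp/my_model", options=load_options)
```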
+ * Add CPU and GPU implementation of modified variation of [`FTRL`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrl)/[`FTRLV2`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/raw_ops/ApplyFtrlV2) that can triggerred by `multiply_linear_by_lr` allowing a learning rate of zero. + +### `tf.data`: + * `tf.data.experimental.dense_to_ragged_batch` works correctly with tuples. + * `tf.data.experimental.dense_to_ragged_batch` to output variable ragged rank. + * `tf.data.experimental.cardinality` is now a method on `tf.data.Dataset`. + * `tf.data.Dataset` now supports `len(Dataset)` when the cardinality is finite. + +### `tf.distribute`: + * Expose experimental [`tf.distribute.DistributedDataset`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedDataset?hl=en) and [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator) to distribute input data when using `tf.distribute` to scale training on multiple devices. + * Added a [`get_next_as_optional`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en#get_next_as_optional) method for [`tf.distribute.DistributedIterator`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/distribute/DistributedIterator?hl=en) class to return a `tf.experimental.Optional` instance that contains the next value for all replicas or none instead of raising an out of range error. Also see *new* [guide on input distribution](https://www.tensorflow.org/tutorials/distribute/input). + * Allow var.assign on MirroredVariables with aggregation=NONE in replica context. Previously this would raise an error. We now allow this because many users and library writers find using `.assign` in replica context to be more convenient, instead of having to use `Strategy.extended.update` which was the previous way of updating variables in this situation. + * `tf.distribute.experimental.MultiWorkerMirroredStrategy` adds support for partial batches. Workers running out of data now continue to participate in the training with empty inputs, instead of raising an error. Learn more about [partial batches here](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). + * Improve the performance of reading metrics eagerly under `tf.distribute.experimental.MultiWorkerMirroredStrategy`. + * Fix the issue that `strategy.reduce()` inside `tf.function` may raise exceptions when the values to reduce are from loops or if-clauses. + * Fix the issue that `tf.distribute.MirroredStrategy` cannot be used together with `tf.distribute.experimental.MultiWorkerMirroredStrategy`. + * Add a `tf.distribute.cluster_resolver.TPUClusterResolver.connect` API to simplify TPU initialization. + +### `tf.keras`: + * Introduces experimental preprocessing layers API (`tf.keras.layers.experimental.preprocessing`) to handle data preprocessing operations such as categorical feature encoding, text vectorization, data normalization, and data discretization (binning). The newly added layers provide a replacement for the legacy feature column API, and support composite tensor inputs. 
+ * Added **categorical data** processing layers: + * `IntegerLookup` & `StringLookup`: build an index of categorical feature values + * `CategoryEncoding`: turn integer-encoded categories into one-hot, multi-hot, or tf-idf encoded representations + * `CategoryCrossing`: create new categorical features representing co-occurrences of previous categorical feature values + * `Hashing`: the hashing trick, for large-vocabulary categorical features + * `Discretization`: turn continuous numerical features into categorical features by binning their values + * Improved **image preprocessing** layers: `CenterCrop`, `Rescaling` + * Improved **image augmentation** layers: `RandomCrop`, `RandomFlip`, `RandomTranslation`, `RandomRotation`, `RandomHeight`, `RandomWidth`, `RandomZoom`, `RandomContrast` + * Improved **`TextVectorization`** layer, which handles string tokenization, n-gram generation, and token encoding + * The `TextVectorization` layer now accounts for the mask_token as part of the vocabulary size when output_mode='int'. This means that, if you have a max_tokens value of 5000, your output will have 5000 unique values (not 5001 as before). + * Change the return value of `TextVectorization.get_vocabulary()` from `byte` to `string`. Users who previously were calling 'decode' on the output of this method should no longer need to do so. + * Introduce new Keras dataset generation utilities : + * **[`image_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image_dataset_from_directory)** is a utility based on `tf.data.Dataset`, meant to replace the legacy `ImageDataGenerator`. It takes you from a structured directory of images to a labeled dataset, in one function call. Note that it doesn't perform image data augmentation (which is meant to be done using preprocessing layers). + * **[`text_dataset_from_directory`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/text_dataset_from_directory)** takes you from a structured directory of text files to a labeled dataset, in one function call. + * **[`timeseries_dataset_from_array`](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/timeseries_dataset_from_array)** is a `tf.data.Dataset`-based replacement of the legacy `TimeseriesGenerator`. It takes you from an array of timeseries data to a dataset of shifting windows with their targets. + * Added [`experimental_steps_per_execution`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/Model?hl=en#compile) + arg to `model.compile` to indicate the number of batches to run per `tf.function` call. This can speed up Keras Models on TPUs up to 3x. + * Extends `tf.keras.layers.Lambda` layers to support multi-argument lambdas, and keyword arguments when calling the layer. + * Functional models now get constructed if *any* tensor in a layer call's arguments/keyword arguments comes from a keras input. Previously the functional api would only work if all of the elements in the first argument to the layer came from a keras input. + * Clean up `BatchNormalization` layer's `trainable` property to act like standard python state when it's used inside `tf.functions` (frozen at tracing time), instead of acting like a pseudo-variable whose updates *kind of sometimes* get reflected in already-traced `tf.function` traces. + * Add the `Conv1DTranspose` layer. + * Refine the semantics of `SensitivitySpecificityBase` derived metrics. 
See the updated API docstrings for [`tf.keras.metrics.SensitivityAtSpecificity`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SensitivityAtSpecificity) and [`tf.keras.metrics.SpecificityAtSensitivty`](https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/metrics/SpecificityAtSensitivity). + +### `tf.lite`: + * Converter + * Restored `inference_input_type` and `inference_output_type` flags in TF 2.x TFLiteConverter (backward compatible with TF 1.x) to support integer (tf.int8, tf.uint8) input and output types in post training full integer quantized models. + * Added support for converting and resizing models with dynamic (placeholder) dimensions. Previously, there was only limited support for dynamic batch size, and even that did not guarantee that the model could be properly resized at runtime. + * Enabled experimental support for a new quantization mode with 16-bit activations and 8-bit weights. See `lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8`. + * CPU + * Fix an issue w/ dynamic weights and `Conv2D` on x86. + * Add a runtime Android flag for enabling `XNNPACK` for optimized CPU performance. + * Add a runtime iOS flag for enabling `XNNPACK` for optimized CPU performance. + * Add a compiler flag to enable building a TFLite library that applies `XNNPACK` delegate automatically when the model has a `fp32` operation. + * GPU + * Allow GPU acceleration starting with internal graph nodes + * Experimental support for quantized models with the Android GPU delegate + * Add GPU delegate whitelist. + * Rename GPU whitelist -> compatibility (list). + * Improve GPU compatibility list entries from crash reports. + * NNAPI + * Set default value for `StatefulNnApiDelegate::Options::max_number_delegated_partitions` to 3. + * Add capability to disable `NNAPI` CPU and check `NNAPI` Errno. + * Fix crashes when using `NNAPI` with target accelerator specified with model containing Conv2d or FullyConnected or LSTM nodes with quantized weights. + * Fix `ANEURALNETWORKS_BAD_DATA` execution failures with `sum`/`max`/`min`/`reduce` operations with `scalar` inputs. + * Hexagon + * TFLite Hexagon Delegate out of experimental. + * Experimental `int8` support for most hexagon ops. + * Experimental per-channel quant support for `conv` in Hexagon delegate. + * Support dynamic batch size in C++ API. + * CoreML + * Opensource CoreML delegate + * Misc + * Enable building Android TFLite targets on Windows + * Add support for `BatchMatMul`. + * Add support for `half_pixel_centers` with `ResizeNearestNeighbor`. + * Add 3D support for `BatchToSpaceND`. + * Add 5D support for `BroadcastSub`, `Maximum`, `Minimum`, `Transpose` and `BroadcastDiv`. + * Rename `kTfLiteActRelu1` to `kTfLiteActReluN1To1`. + * Enable flex delegate on tensorflow.lite.Interpreter Python package. + * Add `Buckettize`, `SparseCross` and `BoostedTreesBucketize` to the flex whitelist. + * Add support for selective registration of flex ops. + * Add missing kernels for flex delegate whitelisted ops. + * Fix issue when using direct `ByteBuffer` inputs with graphs that have dynamic shapes. + * Fix error checking supported operations in a model containing `HardSwish`. + +### Packaging Support + * Added `tf.sysconfig.get_build_info()`. Returns a dict that describes the build environment of the currently installed TensorFlow package, e.g. the NVIDIA CUDA and NVIDIA CuDNN versions used when TensorFlow was built. + +### Profiler + * Fix a subtle use-after-free issue in `XStatVisitor::RefValue()`. 
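For example, the new packaging helper can be queried as below; the key names (such as `"cuda_version"`) are indicative and may differ across platforms and versions:

```
import tensorflow as tf

info = tf.sysconfig.get_build_info()
# Dict describing how the installed package was built, e.g. CUDA/cuDNN versions.
print(info.get("cuda_version"), info.get("cudnn_version"))
```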
+ +### TPU Enhancements + * Adds 3D mesh support in TPU configurations ops. + * Added TPU code for `FTRL` with `multiply_linear_by_lr`. + * Silently adds a new file system registry at `gstpu`. + * Support `restartType` in cloud tpu client. + * Depend on a specific version of google-api-python-client. + * Fixes apiclient import. + +### Tracing and Debugging + * Add a `TFE_Py_Execute` traceme. + +### XLA Support + * Implement stable `argmin` and `argmax` + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +902449@58880@bigcat_chen@ASIC, Abdul Baseer Khan, Abhineet Choudhary, Abolfazl Shahbazi, Adam Hillier, ag.ramesh, Agoniii, Ajay P, Alex Hoffman, Alexander Bayandin, Alexander Grund, Alexandre Abadie, Alexey Rogachevskiy, amoitra, Andrew Stevens, Angus-Luo, Anshuman Tripathy, Anush Elangovan, Artem Mavrin, Ashutosh Hathidara, autoih, Ayushman Kumar, ayushmankumar7, Bairen Yi, Bas Aarts, Bastian Eichenberger, Ben Barsdell, bhack, Bharat Raghunathan, Biagio Montaruli, Bigcat-Himax, blueyi, Bryan Cutler, Byambaa, Carlos Hernandez-Vaquero, Chen Lei, Chris Knorowski, Christian Clauss, chuanqiw, CuiYifeng, Daniel Situnayake, Daria Zhuravleva, Dayananda-V, Deven Desai, Devi Sandeep Endluri, Dmitry Zakharov, Dominic Jack, Duncan Riach, Edgar Liberis, Ehsan Toosi, ekuznetsov139, Elena Zhelezina, Eugene Kuznetsov, Eugene Mikhantiev, Evgenii Zheltonozhskii, Fabio Di Domenico, Fausto Morales, Fei Sun, feihugis, Felix E. Klee, flyingcat, Frederic Bastien, Fredrik Knutsson, frreiss, fsx950223, ganler, Gaurav Singh, Georgios Pinitas, Gian Marco Iodice, Giorgio Arena, Giuseppe Rossini, Gregory Keith, Guozhong Zhuang, gurushantj, Hahn Anselm, Harald Husum, Harjyot Bagga, Hristo Vrigazov, Ilya Persky, Ir1d, Itamar Turner-Trauring, jacco, Jake Tae, Janosh Riebesell, Jason Zaman, jayanth, Jeff Daily, Jens Elofsson, Jinzhe Zeng, JLZ, Jonas Skog, Jonathan Dekhtiar, Josh Meyer, Joshua Chia, Judd, justkw, Kaixi Hou, Kam D Kasravi, Kamil Rakoczy, Karol Gugala, Kayou, Kazuaki Ishizaki, Keith Smiley, Khaled Besrour, Kilaru Yasaswi Sri Chandra Gandhi, Kim, Young Soo, Kristian Hartikainen, Kwabena W. Agyeman, Leslie-Fang, Leslie-Fang-Intel, Li, Guizi, Lukas Geiger, Lutz Roeder, M\U00E5Ns Nilsson, Mahmoud Abuzaina, Manish, Marcel Koester, Marcin Sielski, marload, Martin Jul, Matt Conley, mdfaijul, Meng, Peng, Meteorix, Michael Käufl, Michael137, Milan Straka, Mitchell Vitez, Ml-0, Mokke Meguru, Mshr-H, nammbash, Nathan Luehr, naumkin, Neeraj Bhadani, ngc92, Nick Morgan, nihui, Niranjan Hasabnis, Niranjan Yadla, Nishidha Panpaliya, Oceania2018, oclyke, Ouyang Jin, OverLordGoldDragon, Owen Lyke, Patrick Hemmer, Paul Andrey, Peng Sun, periannath, Phil Pearl, Prashant Dandriyal, Prashant Kumar, Rahul Huilgol, Rajan Singh, Rajeshwar Reddy T, rangjiaheng, Rishit Dagli, Rohan Reddy, rpalakkal, rposts, Ruan Kunliang, Rushabh Vasani, Ryohei Ikegami, Semun Lee, Seo-Inyoung, Sergey Mironov, Sharada Shiddibhavi, ShengYang1, Shraiysh Vaishay, Shunya Ueta, shwetaoj, Siyavash Najafzade, Srinivasan Narayanamoorthy, Stephan Uphoff, storypku, sunchenggen, sunway513, Sven-Hendrik Haase, Swapnil Parekh, Tamas Bela Feher, Teng Lu, tigertang, tomas, Tomohiro Ubukata, tongxuan.ltx, Tony Tonev, Tzu-Wei Huang, Téo Bouvard, Uday Bondhugula, Vaibhav Jade, Vijay Tadikamalla, Vikram Dattu, Vincent Abriou, Vishnuvardhan Janapati, Vo Van Nghia, VoVAllen, Will Battel, William D. 
Irons, wyzhao, Xiaoming (Jason) Cui, Xiaoquan Kong, Xinan Jiang, xutianming, Yair Ehrenwald, Yasir Modak, Yasuhiro Matsumoto, Yixing Fu, Yong Tang, Yuan Tang, zhaozheng09, Zilin Zhu, zilinzhu, 张志豪 # Release 2.1.1 @@ -210,7 +487,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https: `Strategy.extended.update` and `Strategy.extended.update_non_slot`. * Experimental support for shape invariants has been enabled in `tf.function`. See the API docs for - `tf.autograph.experimental.set_loop_options` for additonal info. + `tf.autograph.experimental.set_loop_options` for additional info. * AutoGraph error messages now exclude frames corresponding to APIs internal to AutoGraph. * Improve shape inference for `tf.function` input arguments to unlock more @@ -293,7 +570,7 @@ Coinciding with this change, new releases of [TensorFlow's Docker images](https: also deterministic back-prop of bias-addition in Keras layers) to include when XLA JIT compilation is enabled. * Fix problem, when running on a CUDA GPU and when either environment - variable `TF_DETERMINSTIC_OPS` or environment variable + variable `TF_DETERMINISTIC_OPS` or environment variable `TF_CUDNN_DETERMINISTIC` is set to "true" or "1", in which some layer configurations led to an exception with the message "No algorithm worked!" @@ -336,32 +613,86 @@ This release contains contributions from many people at Google, as well as: TensorFlow 2.1 will be the last TF release supporting Python 2. Python 2 support [officially ends an January 1, 2020](https://www.python.org/dev/peps/pep-0373/#update). [As announced earlier](https://groups.google.com/a/tensorflow.org/d/msg/announce/gVwS5RC8mds/dCt1ka2XAAAJ), TensorFlow will also stop supporting Python 2 starting January 1, 2020, and no more releases are expected in 2019. ## Major Features and Improvements -* The `tensorflow` pip package now includes GPU support by default (same as `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only packages can be downloaded at `tensorflow-cpu` for users who are concerned about package size. -* **Windows users:** Officially-released `tensorflow` Pip packages are now built with Visual Studio 2019 version 16.4 in order to take advantage of the new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new packages, you must install "Microsoft Visual C++ Redistributable for Visual Studio 2015, 2017 and 2019", available from Microsoft's website [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads). - * This does not change the minimum required version for building TensorFlow from source on Windows, but builds enabling `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this flag. Refer to `configure.py` for more information about `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`. - * If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll` (new), are missing on your machine, `import tensorflow` will print a warning message. -* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6. -* `tf.keras` - * Experimental support for mixed precision is available on GPUs and Cloud TPUs. See [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision). - * Introduced the `TextVectorization` layer, which takes as input raw strings and takes care of text standardization, tokenization, n-gram generation, and vocabulary indexing. 
See this [end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3). - * Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be outside of the DistributionStrategy scope, as long as the model was constructed inside of a scope. - * Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and `.predict` is available for Cloud TPUs, Cloud TPU, for all types of Keras models (sequential, functional and subclassing models). - * Automatic outside compilation is now enabled for Cloud TPUs. This allows `tf.summary` to be used more conveniently with Cloud TPUs. - * Dynamic batch sizes with DistributionStrategy and Keras are supported on Cloud TPUs. - * Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in addition to `tf.data.Dataset`. - * Keras reference implementations for many popular models are available in the TensorFlow [Model Garden](https://github.com/tensorflow/models/tree/master/official). -* `tf.data` - * Changes rebatching for `tf.data datasets` + DistributionStrategy for better performance. Note that the dataset also behaves slightly differently, in that the rebatched dataset cardinality will always be a multiple of the number of replicas. - * `tf.data.Dataset` now supports automatic data distribution and sharding in distributed environments, including on TPU pods. - * Distribution policies for `tf.data.Dataset` can now be tuned with 1. `tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2. `tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)` -* `tf.debugging` - * Add `tf.debugging.enable_check_numerics()` and `tf.debugging.disable_check_numerics()` to help debugging the root causes of issues involving infinities and `NaN`s. -* `tf.distribute` - * Custom training loop support on TPUs and TPU pods is avaiable through `strategy.experimental_distribute_dataset`, `strategy.experimental_distribute_datasets_from_function`, `strategy.experimental_run_v2`, `strategy.reduce`. - * Support for a global distribution strategy through `tf.distribute.experimental_set_strategy(),` in addition to `strategy.scope()`. -* `TensorRT` - * [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new) is now supported and enabled by default. This adds support for more TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D, MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the TensorFlow-TensorRT python conversion API is exported as `tf.experimental.tensorrt.Converter`. -* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to "true" or "1", this environment variable makes `tf.nn.bias_add` operate deterministically (i.e. reproducibly), but currently only when XLA JIT compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or "1" also makes cuDNN convolution and max-pooling operate deterministically. This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in both the forward and backward directions when running on a CUDA-enabled GPU. + +* The `tensorflow` pip package now includes GPU support by default (same as + `tensorflow-gpu`) for both Linux and Windows. This runs on machines with and + without NVIDIA GPUs. `tensorflow-gpu` is still available, and CPU-only + packages can be downloaded at `tensorflow-cpu` for users who are concerned + about package size. 
+* **Windows users:** Officially-released `tensorflow` Pip packages are now + built with Visual Studio 2019 version 16.4 in order to take advantage of the + new `/d2ReducedOptimizeHugeFunctions` compiler flag. To use these new + packages, you must install "Microsoft Visual C++ Redistributable for Visual + Studio 2015, 2017 and 2019", available from Microsoft's website + [here](https://support.microsoft.com/help/2977003/the-latest-supported-visual-c-downloads). + * This does not change the minimum required version for building + TensorFlow from source on Windows, but builds enabling + `EIGEN_STRONG_INLINE` can take over 48 hours to compile without this + flag. Refer to `configure.py` for more information about + `EIGEN_STRONG_INLINE` and `/d2ReducedOptimizeHugeFunctions`. + * If either of the required DLLs, `msvcp140.dll` (old) or `msvcp140_1.dll` + (new), are missing on your machine, `import tensorflow` will print a + warning message. +* The `tensorflow` pip package is built with CUDA 10.1 and cuDNN 7.6. +* `tf.keras` + * Experimental support for mixed precision is available on GPUs and Cloud + TPUs. See + [usage guide](https://www.tensorflow.org/guide/keras/mixed_precision). + * Introduced the `TextVectorization` layer, which takes as input raw + strings and takes care of text standardization, tokenization, n-gram + generation, and vocabulary indexing. See this + [end-to-end text classification example](https://colab.research.google.com/drive/1RvCnR7h0_l4Ekn5vINWToI9TNJdpUZB3). + * Keras `.compile` `.fit` `.evaluate` and `.predict` are allowed to be + outside of the DistributionStrategy scope, as long as the model was + constructed inside of a scope. + * Experimental support for Keras `.compile`, `.fit`, `.evaluate`, and + `.predict` is available for Cloud TPUs, Cloud TPU, for all types of + Keras models (sequential, functional and subclassing models). + * Automatic outside compilation is now enabled for Cloud TPUs. This allows + `tf.summary` to be used more conveniently with Cloud TPUs. + * Dynamic batch sizes with DistributionStrategy and Keras are supported on + Cloud TPUs. + * Support for `.fit`, `.evaluate`, `.predict` on TPU using numpy data, in + addition to `tf.data.Dataset`. + * Keras reference implementations for many popular models are available in + the TensorFlow + [Model Garden](https://github.com/tensorflow/models/tree/master/official). +* `tf.data` + * Changes rebatching for `tf.data datasets` + DistributionStrategy for + better performance. Note that the dataset also behaves slightly + differently, in that the rebatched dataset cardinality will always be a + multiple of the number of replicas. + * `tf.data.Dataset` now supports automatic data distribution and sharding + in distributed environments, including on TPU pods. + * Distribution policies for `tf.data.Dataset` can now be tuned with 1. + `tf.data.experimental.AutoShardPolicy(OFF, AUTO, FILE, DATA)` 2. + `tf.data.experimental.ExternalStatePolicy(WARN, IGNORE, FAIL)` +* `tf.debugging` + * Add `tf.debugging.enable_check_numerics()` and + `tf.debugging.disable_check_numerics()` to help debugging the root + causes of issues involving infinities and `NaN`s. +* `tf.distribute` + * Custom training loop support on TPUs and TPU pods is available through + `strategy.experimental_distribute_dataset`, + `strategy.experimental_distribute_datasets_from_function`, + `strategy.experimental_run_v2`, `strategy.reduce`. 
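A compressed sketch of the custom-loop pattern these APIs enable, using the TF 2.1-era names (`experimental_run_v2` has since been replaced by `Strategy.run`); the model and data are placeholders:

```
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # TPUStrategy follows the same pattern
dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 4])).batch(4)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

with strategy.scope():
    dense = tf.keras.layers.Dense(1)

@tf.function
def step(batch):
    per_replica = strategy.experimental_run_v2(
        lambda x: tf.reduce_sum(dense(x)), args=(batch,))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

for batch in dist_dataset:
    step(batch)
```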
+ * Support for a global distribution strategy through + `tf.distribute.experimental_set_strategy(),` in addition to + `strategy.scope()`. +* `TensorRT` + * [TensorRT 6.0](https://developer.nvidia.com/tensorrt#tensorrt-whats-new) + is now supported and enabled by default. This adds support for more + TensorFlow ops including Conv3D, Conv3DBackpropInputV2, AvgPool3D, + MaxPool3D, ResizeBilinear, and ResizeNearestNeighbor. In addition, the + TensorFlow-TensorRT python conversion API is exported as + `tf.experimental.tensorrt.Converter`. +* Environment variable `TF_DETERMINISTIC_OPS` has been added. When set to + "true" or "1", this environment variable makes `tf.nn.bias_add` operate + deterministically (i.e. reproducibly), but currently only when XLA JIT + compilation is *not* enabled. Setting `TF_DETERMINISTIC_OPS` to "true" or + "1" also makes cuDNN convolution and max-pooling operate deterministically. + This makes Keras Conv\*D and MaxPool\*D layers operate deterministically in + both the forward and backward directions when running on a CUDA-enabled GPU. ## Breaking Changes * Deletes `Operation.traceback_with_start_lines` for which we know of no usages. diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 8a0918b416f..d1c1d7dcdef 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -260,6 +260,36 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "armeabi", + values = {"cpu": "armeabi"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "armeabi-v7a", + values = {"cpu": "armeabi-v7a"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "arm64-v8a", + values = {"cpu": "arm64-v8a"}, + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "arm_any", + match_any = [ + ":arm", + ":armeabi", + ":armeabi-v7a", + ":arm64-v8a", + ":linux_aarch64", + ":linux_armhf", + ], +) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, @@ -532,16 +562,14 @@ selects.config_setting_group( package_group( name = "internal", packages = [ - # To pass open source testing in the pip Kokoros. - "//bazel_pip/tensorflow/...", "//learning/brain/swift/x10/...", "//perftools/accelerators/xprof/api/...", - "//third_party/py/autograph/...", - "//third_party/swift/tensorflow/x10/...", - "//third_party/swift/tensorflow_apis/...", "//tensorflow/...", "//tensorflow_estimator/python/estimator/...", "//tensorflow_models/official/...", + "//third_party/py/autograph/...", + "//third_party/swift/tensorflow/x10/...", + "//third_party/swift/tensorflow_apis/...", ], ) diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index f0f977aa0b5..5932dda514d 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -137,7 +137,7 @@ if _running_from_pip_package(): # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: # Load first party dynamic kernels. - _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') if _fi.file_exists(_main_dir): _ll.load_library(_main_dir) @@ -158,4 +158,23 @@ if hasattr(_current_module, 'keras'): setattr(_current_module, "initializers", initializers) # pylint: enable=undefined-variable +# Delete modules that should be hidden from dir(). +# Don't fail if these modules are not available. +# For e.g. this file will be originally placed under tensorflow/_api/v1 which +# does not have 'python', 'core' directories. 
Then, it will be copied +# to tensorflow/ which does have these two directories. +# pylint: disable=undefined-variable +try: + del python +except NameError: + pass +try: + del core +except NameError: + pass +try: + del compiler +except NameError: + pass + # __all__ PLACEHOLDER diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index dad91f2d5b2..0d1d2e56fae 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -147,7 +147,7 @@ if _running_from_pip_package(): # TODO(gunan): Add sanity checks to loaded modules here. for _s in _site_packages_dirs: # Load first party dynamic kernels. - _main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels') + _main_dir = _os.path.join(_s, 'tensorflow/core/kernels') if _fi.file_exists(_main_dir): _ll.load_library(_main_dir) @@ -156,4 +156,25 @@ if _running_from_pip_package(): if _fi.file_exists(_plugin_dir): _ll.load_library(_plugin_dir) +# Delete modules that should be hidden from dir(). +# Don't fail if these modules are not available. +# For e.g. this file will be originally placed under tensorflow/_api/v1 which +# does not have 'python', 'core' directories. Then, it will be copied +# to tensorflow/ which does have these two directories. + +# pylint: disable=undefined-variable +try: + del python +except NameError: + pass +try: + del core +except NameError: + pass +try: + del compiler +except NameError: + pass + + # __all__ PLACEHOLDER diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 410fc22069f..e5efe323922 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -213,6 +213,17 @@ tf_cuda_library( alwayslink = 1, ) +cc_library( + name = "logging", + srcs = ["logging.cc"], + hdrs = ["logging.h"], + deps = [ + ":c_api_macros", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:stringprintf", + ], +) + tf_cuda_library( name = "tf_status_internal", hdrs = [ diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 36a08c8cfc9..2e1759ecea0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -213,7 +213,6 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers, namespace tensorflow { - Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, TF_Buffer* out) { if (out->data != nullptr) { @@ -306,8 +305,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, } // Helpers for loading a TensorFlow plugin (a .so file). 
-Status LoadLibrary(const char* library_filename, void** result, - const void** buf, size_t* len); +Status LoadDynamicLibrary(const char* library_filename, void** result, + const void** buf, size_t* len); // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and @@ -552,7 +551,7 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) { TF_Library* lib_handle = new TF_Library; - status->status = tensorflow::LoadLibrary( + status->status = tensorflow::LoadDynamicLibrary( library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data, &lib_handle->op_list.length); if (!status->status.ok()) { diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 808bcf3bd80..0b4d9993e4d 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -125,6 +125,14 @@ TF_CAPI_EXPORT extern void TF_DeleteBuffer(TF_Buffer*); TF_CAPI_EXPORT extern TF_Buffer TF_GetBuffer(TF_Buffer* buffer); +// -------------------------------------------------------------------------- +// Used to return strings across the C API. The caller does not take ownership +// of the underlying data pointer and is not responsible for freeing it. +typedef struct TF_StringView { + const char* data; + size_t len; +} TF_StringView; + // -------------------------------------------------------------------------- // TF_SessionOptions holds options that can be passed during session creation. typedef struct TF_SessionOptions TF_SessionOptions; diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index 831c6a0ad40..b4297033b6d 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" +#include "tensorflow/core/framework/collective.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -525,12 +526,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def, LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( std::move(new_server), grpc_server->worker_env()->device_mgr, - grpc_server->worker_env()->collective_executor_mgr)); + grpc_server->worker_env()->collective_executor_mgr.get())); } else { LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def)); LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer( /*new_server=*/nullptr, grpc_server->worker_env()->device_mgr, - grpc_server->worker_env()->collective_executor_mgr)); + grpc_server->worker_env()->collective_executor_mgr.get())); } return tensorflow::Status::OK(); #undef LOG_AND_RETURN_IF_ERROR @@ -551,6 +552,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, status->status = EnableCollectiveOps(server_def, ctx); } +TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, + TF_Status* status) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + auto collective_executor_handle = context->GetCollectiveExecutorHandle(); + collective_executor_handle->get()->StartAbort(status->status); +} + TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) { TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList; result->num_items = num_items; diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index d0ffbf125fb..ebd14b4b571 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -230,6 +230,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx, size_t proto_len, TF_Status* status); +// Aborts all ongoing collectives with the specified status. After abortion, +// subsequent collectives will error with this status immediately. +// +// This is intended to be used when a peer failure is detected. There's yet no +// way to reset the collectives other than restarting the program. +TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx, + TF_Status* status); + // Information about the shape of a Tensor and its type. struct TF_ShapeAndType { // Number of dimensions. -1 indicates unknown rank. 
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a77e76644b8..61701bc8b21 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -240,6 +240,8 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/cc/profiler", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/core:lib", @@ -260,6 +262,7 @@ cc_library( ], deps = [ "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:refcount", ], ) @@ -308,6 +311,8 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/util:abstract_stack_trace", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -514,7 +519,6 @@ tf_cuda_cc_test( extra_copts = tfe_xla_copts(), tags = [ "no_windows", - "noasan", # leaks gRPC server instances ], deps = [ ":c_api", @@ -581,7 +585,6 @@ tf_cuda_cc_test( extra_copts = tfe_xla_copts(), tags = [ "no_windows", - "noasan", # leaks gRPC server instances ], deps = [ ":c_api", diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index de041690420..37e6d1bf29c 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -18,11 +18,12 @@ limitations under the License. #include #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/refcount.h" namespace tensorflow { // Abstract interface to a Tensor handle in either tracing or immediate // execution mode. -class AbstractTensorHandle { +class AbstractTensorHandle : public core::RefCounted { protected: enum AbstractTensorHandleKind { kGraph, kMlir, kEager, kTfrt }; explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} @@ -34,14 +35,6 @@ class AbstractTensorHandle { AbstractTensorHandleKind getKind() const { return kind_; } - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus this must be allocated on the heap and - // clients MUST call Release() in order to destroy an instance of this class. - virtual void Release() = 0; - private: const AbstractTensorHandleKind kind_; }; @@ -50,7 +43,7 @@ namespace internal { struct AbstractTensorHandleDeleter { void operator()(AbstractTensorHandle* p) const { if (p != nullptr) { - p->Release(); + p->Unref(); } } }; diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 70acd710166..fefa753c608 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -94,7 +94,6 @@ limitations under the License. 
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/public/version.h" -using tensorflow::int64; using tensorflow::string; namespace { @@ -725,13 +724,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { if (opts->use_tfrt) { #ifdef PLATFORM_GOOGLE - tfrt::SmallVector op_handler_chains; - tfrt::SmallVector device_attributes; - status->status = tfrt::ListOpHandlerChains( - opts->session_options.options, &op_handler_chains, &device_attributes); - if (!status->status.ok()) return nullptr; - return tensorflow::wrap(new tfrt::ContextInterface( - op_handler_chains, device_attributes, opts->async)); + return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async)); #else status->status = tensorflow::errors::Unimplemented("TFRT is not supported"); return nullptr; @@ -974,7 +967,7 @@ int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) { return -1; } - int64 num_elements = -1; + tensorflow::int64 num_elements = -1; status->status = tensorflow::unwrap(h)->NumElements(&num_elements); return num_elements; } @@ -986,7 +979,7 @@ int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, return -1; } - int64 dim = -1; + tensorflow::int64 dim = -1; status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim); return dim; } @@ -1079,11 +1072,13 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory( status->status = context->FindDeviceFromName(device_name, &device); tensorflow::CustomDevice* custom_device = nullptr; if (!status->status.ok()) { - status->status = - context->FindCustomDeviceFromName(device_name, &custom_device); - if (!status->status.ok()) { + if (!context->FindCustomDeviceFromName(device_name, &custom_device)) { deallocator(data, len, deallocator_arg); + status->status = + tensorflow::errors::InvalidArgument(device_name, " unknown device."); return nullptr; + } else { + status->status = tensorflow::Status::OK(); } } std::vector dimvec(num_dims); diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index 6827021455b..dd55f05283b 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -26,14 +26,13 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #endif // TENSORFLOW_EAGER_USE_XLA -using tensorflow::int64; using tensorflow::string; namespace { -std::vector TensorShapeAsVector(const tensorflow::TensorHandle& handle, - tensorflow::Status* status) { - std::vector shape; +std::vector TensorShapeAsVector( + const tensorflow::TensorHandle& handle, tensorflow::Status* status) { + std::vector shape; int rank = -1; *status = handle.NumDims(&rank); if (!status->ok()) { @@ -79,7 +78,7 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( return nullptr; } if (VLOG_IS_ON(3)) { - std::vector shape_to_log = + std::vector shape_to_log = TensorShapeAsVector(*handle, &status->status); if (!status->status.ok()) { // Ignore the status here as we are simply logging. 
@@ -128,14 +127,14 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( } int rank = padded_shape.dimensions_size(); - std::vector dev_dims; + std::vector dev_dims; dev_dims.reserve(rank); if (rank == 1) { // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, dev_dims.push_back(padded_shape.dimensions(0)); } else { for (int i = rank - 1; i >= 0; --i) { - int64 dim_index = padded_shape.layout().minor_to_major(i); + tensorflow::int64 dim_index = padded_shape.layout().minor_to_major(i); dev_dims.push_back(padded_shape.dimensions(dim_index)); } } @@ -146,7 +145,8 @@ TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( // If the tensor is not an XLA tensor, the device shape is // the same as regular tensor shape. - std::vector dev_dims = TensorShapeAsVector(*handle, &status->status); + std::vector dev_dims = + TensorShapeAsVector(*handle, &status->status); if (!status->status.ok()) { return nullptr; } diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index a6547e23454..3738768cf02 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include // NOLINT + #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -174,9 +176,9 @@ void TestFunctionWithPackedInput(const bool remote) { const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; // Create one variable per task. - TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name); - TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name); - TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name); + TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name); + TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name); + TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name); // Add a sync point in order to make sure that variables have been initialized // before the function execution starts. @@ -185,6 +187,9 @@ void TestFunctionWithPackedInput(const bool remote) { VarIsInitialized(ctx, h2); // Pack 3 variable handles into one TFE_TensorHandle. + // When remote is false, function device is placed on task0. Handle types are + // REMOTE, REMOTE, LOCAL on task0. When remote is true, function device is + // placed on task1, Handle types are LOCAL, REMOTE, LOCAL on task1. 
int num_replicas = 3; std::vector handles = {h0, h1, h2}; TFE_TensorHandle* packed_handle = @@ -259,61 +264,64 @@ TEST(CAPI, TestRemoteFunctionWithPackedInput) { TestFunctionWithPackedInput(/*remote=*/true); } +string VariableAddFunctionSignature() { + return " signature {" + " name: 'VariableAddFunction'" + " input_arg {" + " name: 'var0'" + " type: DT_RESOURCE" + " }" + " output_arg {" + " name: 'var0_value'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'read0'" + " op: 'ReadVariableOp'" + " input: 'var0'" + " attr {" + " key: 'dtype'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'add'" + " op: 'Add'" + " input: 'read0:value:0'" + " input: 'read0:value:0'" + " device: '/job:localhost/task:1/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " node_def {" + " name: 'identity'" + " op: 'Identity'" + " input: 'add:z:0'" + " device: '/job:localhost/task:0/device:CPU:0'" + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'var0_value'" + " value: 'identity:output:0'" + " }"; +} + string VariableAddFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( - " signature {" - " name: 'VariableAddFunction'" - " input_arg {" - " name: 'var0'" - " type: DT_RESOURCE" - " }" - " output_arg {" - " name: 'var0_value'" - " type: DT_FLOAT" - " }" - " }" - " node_def {" - " name: 'read0'" - " op: 'ReadVariableOp'" - " input: 'var0'" - " attr {" - " key: 'dtype'" - " value {" - " type: DT_FLOAT" - " }" - " }" - " }" - " node_def {" - " name: 'add'" - " op: 'Add'" - " input: 'read0:value:0'" - " input: 'read0:value:0'" - " device: '/job:localhost/task:1/device:CPU:0'" - " attr {" - " key: 'T'" - " value {" - " type: DT_FLOAT" - " }" - " }" - " }" - " node_def {" - " name: 'identity'" - " op: 'Identity'" - " input: 'add:z:0'" - " device: '/job:localhost/task:0/device:CPU:0'" - " attr {" - " key: 'T'" - " value {" - " type: DT_FLOAT" - " }" - " }" - " }" - " ret {" - " key: 'var0_value'" - " value: 'identity:output:0'" - " }", - &def)); + VariableAddFunctionSignature(), &def)); return def.SerializeAsString(); } @@ -425,6 +433,17 @@ TEST(CAPI, DistributedFunctionGraphPassOnlyOnce) { GraphErrorInjectionPass::enabled_ = false; } +string VariableAddFunctionWithGraphError() { + string signature = VariableAddFunctionSignature(); + // Replace the node 'read0' with 'read0_maybe_with_graph_error', so that the + // error injecting pass can identify and introduce graph pass errors. + signature = std::regex_replace(signature, std::regex("read0"), + "read0_maybe_with_graph_error"); + tensorflow::FunctionDef def; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(signature, &def)); + return def.SerializeAsString(); +} + class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { public: FunctionErrorInjectionPass(string error_node, string error_device) @@ -471,16 +490,19 @@ void TestDistributedFunctionCancellation(bool inject_error) { const char dev2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0"; if (inject_error) { - // Inject a function optimization pass failure when it sees the 'read0' op - // having a requested device `dev2_name`. 
During execution: - // * task:0 processes the main function `VariableAddFunction` and places - // the read0 op on task:2 - // * task:0 partitions the main function with a subgraph containing read0 - // sent to task:2 - // * task:2 graph pass reports an error when it sees read0 with dev2_name + // Inject a function optimization pass failure when it sees the + // 'read0_maybe_with_graph_error' op having a requested device `dev2_name`. + // During execution: + // * task:0 processes main function `VariableAddFunctionWithGraphError` + // and places the 'read0_maybe_with_graph_error' op on task:2 + // * task:0 partitions the main function with a subgraph containing + // 'read0_maybe_with_graph_error' sent to task:2 + // * task:2 graph pass reports an error when it sees + // 'read0_maybe_with_graph_error' with dev2_name tensorflow::function_optimization_registration:: FunctionOptimizationPassRegistration register_test_pass( - std::make_unique("read0", dev2_name)); + std::make_unique( + "read0_maybe_with_graph_error", dev2_name)); } TF_Status* status = TF_NewStatus(); @@ -496,7 +518,7 @@ void TestDistributedFunctionCancellation(bool inject_error) { TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name); EXPECT_NE(var_handle, nullptr); - const string function_def = VariableAddFunction(); + const string function_def = VariableAddFunctionWithGraphError(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/eager/c_api_remote_test.cc b/tensorflow/c/eager/c_api_remote_test.cc index 94c32cf3f30..e99f6d6e170 100644 --- a/tensorflow/c/eager/c_api_remote_test.cc +++ b/tensorflow/c/eager/c_api_remote_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/strings/str_cat.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" @@ -115,40 +116,42 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } -string MatMulFunction() { +string MatMulFunction(const string& matmul_device) { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( - " signature {" - " name: 'MatMulFunction'" - " input_arg {" - " name: 'a'" - " type: DT_FLOAT" - " }" - " input_arg {" - " name: 'b'" - " type: DT_FLOAT" - " }" - " output_arg {" - " name: 'm'" - " type: DT_FLOAT" - " }" - " }" - " node_def {" - " name: 'matmul'" - " op: 'MatMul'" - " input: 'a'" - " input: 'b'" - " attr {" - " key: 'T'" - " value {" - " type: DT_FLOAT" - " }" - " }" - " }" - " ret {" - " key: 'm'" - " value: 'matmul:product'" - " }", + absl::StrCat(" signature {" + " name: 'MatMulFunction'" + " input_arg {" + " name: 'a'" + " type: DT_FLOAT" + " }" + " input_arg {" + " name: 'b'" + " type: DT_FLOAT" + " }" + " output_arg {" + " name: 'm'" + " type: DT_FLOAT" + " }" + " }" + " node_def {" + " name: 'matmul'" + " op: 'MatMul'" + " input: 'a'" + " input: 'b'" + " device: '", + matmul_device, "'", + " attr {" + " key: 'T'" + " value {" + " type: DT_FLOAT" + " }" + " }" + " }" + " ret {" + " key: 'm'" + " value: 'matmul:product'" + " }"), &def)); return def.SerializeAsString(); } @@ -157,7 +160,8 @@ string MatMulFunction() { // which creates a remote remote input, to simulate a scenario that the remote // input is not ready when we start running an op or a function. void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, - bool heavy_load_on_streaming_rpc) { + bool heavy_load_on_streaming_rpc, + bool remote_func_outputs = false) { tensorflow::ServerDef server_def = GetServerDef(3); // This server def has the task index set to 0. @@ -214,7 +218,8 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, TFE_Op* matmul = nullptr; if (func) { - string function_def = MatMulFunction(); + const string matmul_device = remote_func_outputs ? task2_name : ""; + string function_def = MatMulFunction(matmul_device); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), status); CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); @@ -250,7 +255,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); // TODO(gjn): Add support for waiting on async local mirrors - if (!remote && !async) { + if (!remote && !async && !remote_func_outputs) { auto remote_arg = tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2)); // The input handles should never change since they have been mirrored. @@ -329,6 +334,19 @@ TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, /*heavy_load_on_streaming_rpc=*/false); } +// TODO(b/162618595): Enable this test once we remove the check of remote +// outputs in ProcessFunctionLibraryRuntime. 
+TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) { + TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false, + /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) { + TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) { // A remote input may be not ready when we start running a function. Test that // the function execution should wait until the remote input is ready. diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 4b5ad8f50f7..192f10533a6 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -88,6 +88,20 @@ TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) { return th; } +TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, + float data[], int64_t dims[], + int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) { constexpr int64_t dims[] = {100, 100}; constexpr int num_elements = dims[0] * dims[1]; @@ -143,7 +157,7 @@ TFE_TensorHandle* TestVariable(TFE_Context* ctx, float value, if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrType(op, "dtype", TF_FLOAT); TFE_OpSetAttrShape(op, "shape", {}, 0, status); - TFE_OpSetAttrString(op, "container", "", 0); + TFE_OpSetAttrString(op, "container", "localhost", 0); TFE_OpSetAttrString(op, "shared_name", "", 0); if (!device_name.empty()) { TFE_OpSetDevice(op, device_name.c_str(), status); diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index fcf62223f14..fcf407aa9c3 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -34,6 +34,12 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx); // Return a tensor handle containing a 2x2 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx); +// Return a tensor handle containing 2D matrix containing given data and +// dimensions +TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, + float data[], int64_t dims[], + int num_dims); + // Return a tensor handle containing a 100x100 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 605a60c186c..8408f7ef60f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -147,7 +147,7 @@ TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) { void TF_DeleteAbstractOp(TF_AbstractOp* op) { unwrap(op)->Release(); } -void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Release(); } +void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { unwrap(t)->Unref(); } TF_OutputList* TF_NewOutputList() { return wrap(new OutputList); } void TF_DeleteOutputList(TF_OutputList* o) { delete unwrap(o); } 
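The new `TestMatrixTensorHandleWithInput` helper declared above is used throughout the unified-API tests later in this patch. The sketch below shows the intended call pattern; it is illustrative only and assumes an existing `TFE_Context* ctx` owned by the surrounding test.

```cpp
// Illustrative sketch: build a 2x2 float tensor handle from caller-supplied
// data using the new test helper, then release it.
float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
int num_dims = sizeof(dims) / sizeof(dims[0]);
TFE_TensorHandle* h =
    TestMatrixTensorHandleWithInput(ctx, vals, dims, num_dims);
// ... feed `h` to an op, or wrap it via TF_CreateAbstractTensorFromEagerTensor ...
TFE_DeleteTensorHandle(h);
```
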
diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 6165a7d14a3..7bda3aed76d 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -33,6 +33,7 @@ limitations under the License. using tensorflow::dyn_cast; using tensorflow::string; +using tensorflow::gtl::ArraySlice; namespace tensorflow { namespace tracing { @@ -48,7 +49,6 @@ class GraphTensor : public TracingTensorHandle { public: explicit GraphTensor(TF_Output output) : TracingTensorHandle(kGraph), output_(output) {} - void Release() override { delete this; } tensorflow::DataType DataType() const override { return static_cast(TF_OperationOutputType(output_)); @@ -138,20 +138,23 @@ class GraphOperation : public TracingOperation { Status SetAttrString(const char* attr_name, const char* data, size_t length) override { - return tensorflow::errors::Unimplemented( - "SetAttrString has not been implemented yet."); + tensorflow::StringPiece s(data, length); + op_->node_builder.Attr(attr_name, s); + return Status::OK(); } Status SetAttrInt(const char* attr_name, int64_t value) override { - return tensorflow::errors::Unimplemented( - "SetAttrInt has not been implemented yet."); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + op_->node_builder.Attr(attr_name, static_cast(value)); + return Status::OK(); } Status SetAttrFloat(const char* attr_name, float value) override { - return tensorflow::errors::Unimplemented( - "SetAttrFloat has not been implemented yet."); + op_->node_builder.Attr(attr_name, value); + return Status::OK(); } Status SetAttrBool(const char* attr_name, bool value) override { - return tensorflow::errors::Unimplemented( - "SetAttrBool has not been implemented yet."); + op_->node_builder.Attr(attr_name, value); + return Status::OK(); } Status SetAttrType(const char* const attr_name, DataType value) override { if (!op_) { @@ -164,8 +167,15 @@ class GraphOperation : public TracingOperation { } Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) override { - return tensorflow::errors::Unimplemented( - "SetAttrShape has not been implemented yet."); + PartialTensorShape shape; + if (num_dims >= 0) { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shape = PartialTensorShape(ArraySlice( + reinterpret_cast(dims), num_dims)); + } + op_->node_builder.Attr(attr_name, shape); + return Status::OK(); } Status SetAttrFunction(const char* attr_name, const AbstractOperation* value) override { @@ -174,8 +184,10 @@ class GraphOperation : public TracingOperation { } Status SetAttrFunctionName(const char* attr_name, const char* value, size_t length) override { - return tensorflow::errors::Unimplemented( - "SetAttrFunctionName has not been implemented yet."); + tensorflow::NameAttrList func_name; + func_name.set_name(string(value, value + length)); + op_->node_builder.Attr(attr_name, func_name); + return Status::OK(); } Status SetAttrTensor(const char* attr_name, AbstractTensorInterface* tensor) override { @@ -184,33 +196,71 @@ class GraphOperation : public TracingOperation { } Status SetAttrStringList(const char* attr_name, const void* const* values, const size_t* lengths, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrStringList has not been implemented yet."); + if (strcmp(attr_name, tensorflow::kColocationAttrName) 
== 0) { + op_->colocation_constraints.clear(); + for (int i = 0; i < num_values; ++i) { + op_->colocation_constraints.emplace(static_cast(values[i]), + lengths[i]); + } + } else { + std::vector v; + v.reserve(num_values); + for (int i = 0; i < num_values; ++i) { + v.emplace_back(static_cast(values[i]), lengths[i]); + } + op_->node_builder.Attr(attr_name, v); + } + return Status::OK(); } Status SetAttrFloatList(const char* attr_name, const float* values, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrFloatList has not been implemented yet."); + op_->node_builder.Attr(attr_name, + ArraySlice(values, num_values)); + return Status::OK(); } Status SetAttrIntList(const char* attr_name, const int64_t* values, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrIntList has not been implemented yet."); + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + op_->node_builder.Attr( + attr_name, + ArraySlice( + reinterpret_cast(values), num_values)); + return Status::OK(); } Status SetAttrTypeList(const char* attr_name, const DataType* values, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrTypeList has not been implemented yet."); + op_->node_builder.Attr(attr_name, + ArraySlice(values, num_values)); + return Status::OK(); } Status SetAttrBoolList(const char* attr_name, const unsigned char* values, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrBoolList has not been implemented yet."); + std::unique_ptr b(new bool[num_values]); + for (int i = 0; i < num_values; ++i) { + b[i] = values[i]; + } + op_->node_builder.Attr(attr_name, + ArraySlice(b.get(), num_values)); + + return Status::OK(); } Status SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) override { - return tensorflow::errors::Unimplemented( - "SetAttrShapeList has not been implemented yet."); + std::vector shapes; + shapes.reserve(num_values); + for (int i = 0; i < num_values; ++i) { + if (num_dims[i] < 0) { + shapes.emplace_back(); + } else { + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + shapes.emplace_back(ArraySlice( + reinterpret_cast(dims[i]), num_dims[i])); + } + } + op_->node_builder.Attr(attr_name, shapes); + return Status::OK(); } Status SetAttrFunctionList( const char* attr_name, diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index a25dccc4638..c56e8ab05fc 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -92,9 +92,255 @@ TEST_P(UnifiedCAPI, TestBasicEager) { TF_DeleteExecutionContext(ctx); } +// MatMul Test +TEST_P(UnifiedCAPI, TestBasicEagerMatMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + /* Want to test simple MatMul example: + [[0,0], * [[0,0], = [[0,0], + [0,0]] [0,0]] [0,0]] + */ + + // Build an abstract input tensor. 
+ int64_t dims[] = {2, 2}; // Matrices will be 2 x 2 + int num_dims = sizeof(dims) / sizeof(dims[0]); + + float vals[] = {0.0f, 0.0f, 0.0f, 0.0f}; + TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get()); + TFE_TensorHandle* t = + TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims); + + TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor( + t, status.get()); // get abstract tensor + + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. + auto* op = TF_NewAbstractOp(ctx); + TF_AbstractOpSetOpType(op, "MatMul", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {at, at}; + TF_OutputList* o = TF_NewOutputList(); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute. + TF_ExecuteOperation(op, 2, inputs, o, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(at); + + // Verify the results. + ASSERT_EQ(1, TF_OutputListNumOutputs(o)); + TF_AbstractTensor* result = TF_OutputListGet(o, 0); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + + // Copy Tensor data into an array. + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(result_tensor), + TF_TensorByteSize(result_tensor)); + + int data_len = 4; // length of result_data + for (int i = 0; i < data_len; i++) { + EXPECT_EQ(result_data[i], 0); + } + + TF_DeleteTensor(result_tensor); + TF_DeleteAbstractTensor(result); + TF_DeleteOutputList(o); + TF_DeleteExecutionContext(ctx); +} + +// MatMul Test 2 +TEST_P(UnifiedCAPI, TestBasicEagerMatMul2) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + /* Want to test simple MatMul example with abstract tensors: + [[1,2], * [[5,6], = [[19,22], + [3,4]] [7,8]] [43,50]] + */ + + // Build 1st Matrix. + int64_t dims[] = {2, 2}; // Matrices will be 2 x 2 + int num_dims = sizeof(dims) / sizeof(dims[0]); + + float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get()); + TFE_TensorHandle* t1 = + TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims); + + TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor( + t1, status.get()); // get abstract tensor + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build 2nd Matrix. + float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f}; + TFE_TensorHandle* t2 = + TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims); + + TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor( + t2, status.get()); // get abstract tensor + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. 
+ auto* op = TF_NewAbstractOp(ctx); + TF_AbstractOpSetOpType(op, "MatMul", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {at1, at2}; + TF_OutputList* o = TF_NewOutputList(); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute. + TF_ExecuteOperation(op, 2, inputs, o, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(at1); + TF_DeleteAbstractTensor(at2); + + // Verify the results. + ASSERT_EQ(1, TF_OutputListNumOutputs(o)); + TF_AbstractTensor* result = TF_OutputListGet(o, 0); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + + // Copy Tensor data into array. + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(result_tensor), + TF_TensorByteSize(result_tensor)); + + // Build expected result & verify. + float e_vals[] = {19.0f, 22.0f, 43.0f, 50.0f}; + + int data_len = 4; // length of e_vals + for (int i = 0; i < data_len; i++) { + EXPECT_EQ(result_data[i], e_vals[i]); + } + + TF_DeleteTensor(result_tensor); + TF_DeleteAbstractTensor(result); + TF_DeleteOutputList(o); + TF_DeleteExecutionContext(ctx); +} + +// MatAdd +TEST_P(UnifiedCAPI, TestBasicEagerMatAdd) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + /* Want to test simple MatAdd example with abstract tensors: + [[1,2] , + [[5,6], = [[6,8], + [3,4] ] [7,8] ] [10,12]] + */ + + // Build 1st Matrix. + int64_t dims[] = {2, 2}; // Matrices will be 2 x 2 + int num_dims = sizeof(dims) / sizeof(dims[0]); + + float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get()); + TFE_TensorHandle* t1 = + TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims); + + TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor( + t1, status.get()); // get abstract tensor + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build 2nd Matrix. + float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f}; + TFE_TensorHandle* t2 = + TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims); + + TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor( + t2, status.get()); // get abstract tensor + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. + auto* op = TF_NewAbstractOp(ctx); + TF_AbstractOpSetOpType(op, "Add", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {at1, at2}; + TF_OutputList* o = TF_NewOutputList(); + TF_OutputListSetNumOutputs(o, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute. 
+ TF_ExecuteOperation(op, 2, inputs, o, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. + TF_DeleteAbstractOp(op); + TF_DeleteAbstractTensor(at1); + TF_DeleteAbstractTensor(at2); + + // Verify the results. + ASSERT_EQ(1, TF_OutputListNumOutputs(o)); + TF_AbstractTensor* result = TF_OutputListGet(o, 0); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + + // Copy Tensor data into array. + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(result_tensor), + TF_TensorByteSize(result_tensor)); + + // Build expected result & verify. + float e_vals[] = {6.0f, 8.0f, 10.0f, 12.0f}; + + int data_len = 4; // length of e_vals + for (int i = 0; i < data_len; i++) { + EXPECT_EQ(result_data[i], e_vals[i]); + } + + TF_DeleteTensor(result_tensor); + TF_DeleteAbstractTensor(result); + TF_DeleteOutputList(o); + TF_DeleteExecutionContext(ctx); +} + TEST_P(UnifiedCAPI, TestBasicGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); + // Start a new function / execution context. string fn_name = "double"; TF_ExecutionContext* graph_ctx = @@ -142,6 +388,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + // Build the abstract op to run the function. TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get()); @@ -180,6 +427,111 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { TF_DeleteExecutionContext(eager_execution_ctx); } +// Graph Tracing for MatMul +TEST_P(UnifiedCAPI, TestBasicGraphMatMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + // Start a new function / execution context. + string fn_name = "matrix_multiply"; + TF_ExecutionContext* graph_ctx = + TF_CreateFunction(fn_name.c_str(), status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + auto* placeholder_t = + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract operation. + auto* matmul_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(matmul_op, "MatMul", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_AbstractOpSetOpName(matmul_op, "my_matmul", status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build inputs and outputs. + TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t}; + TF_OutputList* mm_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(mm_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Execute. + TF_ExecuteOperation(matmul_op, 2, inputs, mm_outputs, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Clean up operation and inputs. 
+ TF_DeleteAbstractOp(matmul_op); + + TF_AbstractFunction* func = + TF_FinalizeFunction(graph_ctx, mm_outputs, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + /* Now that the graph is built, test graph implementation on matmul example: + [[1,1] , * [[1,1] , = [[2,2], + [1,1]] [1,1]] [2,2]] + */ + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TFE_DeleteContextOptions(opts); + + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Build an abstract input tensor. + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get()); + + float vals[] = {1.0f, 1.0f, 1.0f, 1.0f}; + int64_t dims[] = {2, 2}; // Matrices will be 2 x 2 + int num_dims = sizeof(dims) / sizeof(dims[0]); + + TFE_TensorHandle* input_eager = + TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims); + TF_AbstractTensor* input_t = + TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_OutputListSetNumOutputs(mm_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_ExecuteOperation(fn_op, 1, &input_t, mm_outputs, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + ASSERT_EQ(1, TF_OutputListNumOutputs(mm_outputs)); + TF_AbstractTensor* final_result = TF_OutputListGet(mm_outputs, 0); + TFE_TensorHandle* final = + TF_AbstractTensorGetEagerTensor(final_result, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t)); + + int data_len = 4; + for (int i = 0; i < data_len; i++) { + ASSERT_EQ(result_data[i], 2.0f); + } + + TF_DeleteAbstractTensor(final_result); + TF_DeleteOutputList(mm_outputs); + TF_DeleteAbstractTensor(placeholder_t); + TF_DeleteAbstractOp(fn_op); + TF_DeleteAbstractTensor(input_t); + TF_DeleteTensor(f_t); + TF_DeleteAbstractFunction(func); + + TF_DeleteExecutionContext(eager_execution_ctx); +} + TEST_P(UnifiedCAPI, TestMultiOutputGraph) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -336,6 +688,217 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TF_DeleteAbstractFunction(func); } +TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + + // Start a new function / execution context. 
+ string fn_name = "two_adds_and_matmul"; + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a first "Add" computing `arg0 + arg1`. + TF_AbstractTensor* add_output1; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add1", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg0, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + + // Extract the resulting tensor. + add_output1 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // Same with a second "Add" computing `arg1 + arg1`. + TF_AbstractTensor* add_output2; + { + // Build an abstract operation, inputs and output. + auto* add_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(add_op, "Add", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(add_op, "my_add2", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {arg1, arg1}; + TF_OutputList* add_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(add_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(add_op); + + // Extract the resulting tensor. + add_output2 = TF_OutputListGet(add_outputs, 0); + TF_DeleteOutputList(add_outputs); + } + + // 3rd Output will be Matrix Multiplication of add_output1 and add_output2 + TF_AbstractTensor* mm_output; + { + // Build an abstract operation, inputs and output. + auto* mm_op = TF_NewAbstractOp(graph_ctx); + TF_AbstractOpSetOpType(mm_op, "MatMul", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractOpSetOpName(mm_op, "mm", s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_AbstractTensor* inputs[2] = {add_output1, add_output2}; + TF_OutputList* mm_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(mm_outputs, 1, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Trace the operation now (create a node in the graph). + TF_ExecuteOperation(mm_op, 2, inputs, mm_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(mm_op); + + // Extract the resulting tensor. + mm_output = TF_OutputListGet(mm_outputs, 0); + TF_DeleteOutputList(mm_outputs); + } + + // Finalize the function by providing the returned values. + TF_AbstractFunction* func; + { + // We want to return the output of both add operations and MatMul operation, + // create a new list and populate it. 
+ TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListPushBack(func_outputs, add_output1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_OutputListPushBack(func_outputs, add_output2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_OutputListPushBack(func_outputs, mm_output, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + func = TF_FinalizeFunction(graph_ctx, func_outputs, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteOutputList(func_outputs); + } + + /** + * We traced so far this function: + * + * def two_adds_and_mm(A, B): + * my_add1 = A + B + * my_add2 = B + B + * mm = tf.MatMul(my_add1,my_add2) + * return my_add1, my_add2, mm + * + * Now we will execute this function with an eager context: + * + * A =[[0, 1],[1, 0]] + * B =[[1, 0],[0, 1]] + * + * output1, output2, output3 = two_adds_and_mm(A, B) + * + * We expect outputs: + * + * output1 = [[1, 1],[1, 1]] + * output2 = [[2, 0],[0, 2]] + * output3 = [[2, 2],[2, 2]] + * + */ + + // Build eager context. + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TF_ExecutionContext* eager_execution_ctx = + TF_NewEagerExecutionContext(opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TFE_DeleteContextOptions(opts); + + TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build the abstract op to run the function. + TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx); + TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Build two abstract input tensors as function arguments. + std::vector func_args; + { + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(eager_execution_ctx, s); + + // 1st Arg + float vals1[] = {0.0f, 1.0f, 1.0f, 0.0f}; + int64_t dims[] = {2, 2}; // Matrices will be 2 x 2 + int num_dims = sizeof(dims) / sizeof(dims[0]); + + TFE_TensorHandle* input_eager = + TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // 2nd Arg + float vals2[] = {1.0f, 0.0f, 0.0f, 1.0f}; + input_eager = + TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims); + func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s)); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + } + + TF_OutputList* func_outputs = TF_NewOutputList(); + TF_OutputListSetNumOutputs(func_outputs, 3, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs, + s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteAbstractOp(fn_op); + for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t); + + ASSERT_EQ(3, TF_OutputListNumOutputs(func_outputs)); + + float expected_outputs[3][4] = {{1.0f, 1.0f, 1.0f, 1.0f}, + {2.0f, 0.0f, 0.0f, 2.0f}, + {2.0f, 2.0f, 2.0f, 2.0f}}; + + float result_data[4]; + for (int idx = 0; idx < 3; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t)); + + // Verify results for each output + for (int j = 0; j < 4; j++) { + 
ASSERT_EQ(result_data[j], expected_outputs[idx][j]); + } + + TF_DeleteTensor(f_t); + } + + // Free memory associated with add and MatMul outputs + for (int idx = 0; idx < 3; ++idx) { + TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx); + TF_DeleteAbstractTensor(result); + } + + TF_DeleteOutputList(func_outputs); + TF_DeleteExecutionContext(eager_execution_ctx); + TF_DeleteAbstractFunction(func); +} + TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 3a7a6282192..39cadd421e2 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -51,25 +51,14 @@ int64 ToId(AbstractTensorHandle* t) { TapeTensor::TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx) : handle_(handle), ctx_(ctx) { - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Ref(); + handle_->Ref(); } TapeTensor::TapeTensor(const TapeTensor& other) { handle_ = other.handle_; - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Ref(); + handle_->Ref(); ctx_ = other.ctx_; } -TapeTensor::~TapeTensor() { - // TODO(b/160888114): Make AbstractTensorHandle RefCounted. Right now we rely - // on the client to keep this tensor live for the duration of the gradient - // computation. - // handle_->Unref(); -} +TapeTensor::~TapeTensor() { handle_->Unref(); } tensorflow::int64 TapeTensor::GetID() const { return ToId(handle_); } @@ -112,7 +101,7 @@ AbstractTensorHandle* TapeTensor::ZerosLike() const { } if (isa(op.get())) { s = dyn_cast(op.get())->SetOpName( - absl::StrCat("OnesLike", ToId(handle_)).c_str()); + absl::StrCat("ZerosLike", ToId(handle_)).c_str()); if (!s.ok()) { return nullptr; } @@ -175,7 +164,8 @@ Status TapeVSpace::CallBackwardFunction( gtl::ArraySlice output_gradients, std::vector* result) const { if (backward_function == nullptr) return Status::OK(); - return backward_function->Compute(output_gradients, result); + Context ctx = {ctx_}; + return backward_function->Compute(&ctx, output_gradients, result); } // Looks up the ID of a Gradient. @@ -191,7 +181,7 @@ TapeTensor TapeVSpace::TapeTensorFromGradient(AbstractTensorHandle* g) const { void TapeVSpace::MarkAsResult(AbstractTensorHandle* gradient) const {} void TapeVSpace::DeleteGradient(AbstractTensorHandle* gradient) const { - gradient->Release(); + gradient->Unref(); } // Helper functions which delegate to `AbstractOperation`, update @@ -373,6 +363,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, input_ids[i] = ToId(forward_op_->inputs[i]); input_dtypes[i] = forward_op_->inputs[i]->DataType(); } + for (int i = 0; i < *num_retvals; i++) { + // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. 
+ forward_op_->outputs.push_back(retvals[i]); + } std::vector tape_tensors; for (auto t : retvals) { tape_tensors.push_back(TapeTensor(t, ctx)); diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index e09b6ff8613..267ee5b7ab2 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -31,7 +31,8 @@ namespace gradients { // // class AddGradientFunction : public GradientFunction { // public: -// Status Compute(absl::Span grad_inputs, +// Status Compute(Context* ctx, +// absl::Span grad_inputs, // std::vector* grad_outputs) override { // grad_outputs->resize(2); // (*grad_outputs)[0] = grad_inputs[0]; @@ -50,11 +51,16 @@ namespace gradients { // Status RegisterGradients(GradientRegistry* registry) { // return registry->Register("Add", AddRegisterer); // } +struct Context { + public: + AbstractContext* ctx; +}; class GradientFunction { public: // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in // `grad_inputs`. - virtual Status Compute(absl::Span grad_inputs, + virtual Status Compute(Context* ctx, + absl::Span grad_inputs, std::vector* grad_outputs) = 0; virtual ~GradientFunction() {} }; diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 5820058f3e2..41993b3e125 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" @@ -42,55 +44,12 @@ class CppGradients } }; -// Creates an Identity op. 
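// ---------------------------------------------------------------------------
// Illustrative sketch (editorial, not part of this change): with the new
// GradientFunction::Compute(Context*, ...) signature added to gradients.h
// above, a gradient body can build ops from the AbstractContext carried in
// `Context` instead of capturing a context at registration time, as the
// removed in-test AddGradientFunction below used to. The class here is purely
// hypothetical (a "Neg" gradient); the real Add/Exp gradients registered in
// this test now come from tensorflow/c/experimental/gradients/math_grad.h.
// The span/vector template arguments, elided by the diff rendering, are
// assumed to be AbstractTensorHandle*.
class NegGradientFunction : public GradientFunction {
 public:
  Status Compute(Context* ctx,
                 absl::Span<AbstractTensorHandle*> grad_inputs,
                 std::vector<AbstractTensorHandle*>* grad_outputs) override {
    // d(-x)/dx = -1, so negate the incoming gradient with an op created from
    // the tape-provided context. Op-name handling for tracing contexts is
    // omitted for brevity.
    AbstractOperationPtr neg_op(ctx->ctx->CreateOperation());
    TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
    TF_RETURN_IF_ERROR(neg_op->AddInput(grad_inputs[0]));
    grad_outputs->resize(1);
    int num_retvals = 1;
    return neg_op->Execute(absl::MakeSpan(*grad_outputs), &num_retvals);
  }
  ~NegGradientFunction() override {}
};

GradientFunction* NegRegisterer(const ForwardOperation& op) {
  return new NegGradientFunction;
}
// ---------------------------------------------------------------------------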
-Status Identity(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, const char* name) { - AbstractOperationPtr identity_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR( - identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); - if (isa(identity_op.get())) { - TF_RETURN_IF_ERROR(dyn_cast(identity_op.get()) - ->SetOpName(name)); - } - TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); - int num_retvals = 1; - TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals)); +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); return Status::OK(); } -// =================== Register gradients for Add ============================ -class AddGradientFunction : public GradientFunction { - public: - explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {} - Status Compute(absl::Span grad_inputs, - std::vector* grad_outputs) override { - grad_outputs->resize(2); - std::vector identity_outputs(1); - TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), "Id0")); - (*grad_outputs)[0] = identity_outputs[0]; - TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]}, - absl::MakeSpan(identity_outputs), "Id1")); - (*grad_outputs)[1] = identity_outputs[0]; - return Status::OK(); - } - ~AddGradientFunction() override {} - - private: - AbstractContext* ctx_; -}; - -GradientFunction* AddRegisterer(const ForwardOperation& op) { - return new AddGradientFunction(op.ctx); -} - -Status RegisterGradients(GradientRegistry* registry) { - return registry->Register("Add", AddRegisterer); -} - -// =================== End gradient registrations ============================ - // Computes `inputs[0] + inputs[1]` and records it on the tape. Status Add(AbstractContext* ctx, Tape* tape, absl::Span inputs, @@ -112,6 +71,26 @@ Status Add(AbstractContext* ctx, Tape* tape, registry); } +// Computes `exp(inputs[0])` and records it on the tape. +Status Exp(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr exp_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(exp_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(exp_op.get())->SetOpName("my_exp")); + } + TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op)); + int num_retvals = 1; + return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -136,7 +115,7 @@ Status AddGradModel(AbstractContext* ctx, source_tensors_that_are_targets, /*output_gradients=*/{}, &out_grads)); for (auto add_output : add_outputs) { - add_output->Release(); + add_output->Unref(); } outputs[0] = out_grads[0]; outputs[1] = out_grads[1]; @@ -144,6 +123,35 @@ Status AddGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// y = exp(inputs[0]) +// return grad(y, {inputs[0]}) +Status ExpGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + std::vector exp_outputs(1); + TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs), + registry)); // Compute x+y. 
+ std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads)); + for (auto exp_output : exp_outputs) { + exp_output->Unref(); + } + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + AbstractContext* BuildFunction(const char* fn_name) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -187,14 +195,15 @@ Status RunModel(Model model, AbstractContext* ctx, TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), absl::MakeSpan(output_list.outputs), registry)); for (auto func_input : func_inputs) { - func_input->Release(); + func_input->Unref(); } AbstractFunction* func = nullptr; TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) ->Finalize(&output_list, &func)); scoped_func.reset(func); - output_list.outputs[0]->Release(); - output_list.outputs[1]->Release(); + for (auto output : output_list.outputs) { + output->Unref(); + } TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); } @@ -295,7 +304,7 @@ TEST_P(CppGradients, TestAddGrad) { ASSERT_EQ(errors::OK, s.code()) << s.error_message(); auto result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0); - outputs[0]->Release(); + outputs[0]->Unref(); TF_DeleteTensor(result_tensor); result_tensor = nullptr; @@ -303,17 +312,61 @@ TEST_P(CppGradients, TestAddGrad) { ASSERT_EQ(errors::OK, s.code()) << s.error_message(); result_value = static_cast(TF_TensorData(result_tensor)); EXPECT_EQ(*result_value, 1.0); - outputs[1]->Release(); + outputs[1]->Unref(); TF_DeleteTensor(result_tensor); } +TEST_P(CppGradients, TestExpGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Pseudo-code: + // + // tape.watch(x) + // y = exp(x) + // outputs = tape.gradient(y, x) + std::vector outputs(1); + s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* result_tensor; + s = getValue(outputs[0], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_NEAR(*result_value, 2.718, 0.001); + outputs[0]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + // TODO(b/160888630): Enable this test with mlir after AddInputList is // supported. It is needed for AddN op which is used for gradient aggregation. 
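// ---------------------------------------------------------------------------
// Editorial note (not part of this change): TestExpGrad above expects roughly
// 2.718 because the model computes y = exp(x) and asks the tape for dy/dx.
// Since d(exp(x))/dx = exp(x), the gradient at x = 1.0f is e, approximately
// 2.71828, which the test accepts via EXPECT_NEAR(*result_value, 2.718, 0.001).
// A hypothetical equivalent check would be
//   EXPECT_NEAR(*result_value, std::exp(1.0f), 0.001);
// ---------------------------------------------------------------------------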
#ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, ::testing::Combine(::testing::Values("graphdef"), - /*tfrt*/ ::testing::Values(false), + /*tfrt*/ ::testing::Values(true, false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index 31a75c5b8c7..ee212b21a96 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" @@ -26,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/abstract_stack_trace.h" struct TFE_Op; @@ -36,6 +38,10 @@ class ImmediateExecutionOperation : public AbstractOperation { public: virtual void Clear() = 0; + // Returns the inputs of this op. + virtual absl::Span GetInputs() + const = 0; + virtual const tensorflow::OpDef* OpDef() const = 0; virtual Status InputLength(const char* input_name, int* length) = 0; @@ -44,6 +50,12 @@ class ImmediateExecutionOperation : public AbstractOperation { // Experimental virtual Status SetUseXla(bool enable) = 0; + // Set stack trace to be used for potential async error reporting. + virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0; + + // Returns the stack trace set by `SetStackTrace` if exists. + virtual absl::optional GetStackTrace() = 0; + // For LLVM style RTTI. static bool classof(const AbstractOperation* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index f7c77aa06db..6d32d482747 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -50,6 +50,14 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { // Return a copy of the handle. virtual ImmediateExecutionTensorHandle* Copy() = 0; + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus this must be allocated on the heap and + // clients MUST call Release() in order to destroy an instance of this class. + virtual void Release() = 0; + // For LLVM style RTTI. 
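// ---------------------------------------------------------------------------
// Illustrative sketch (editorial, not part of this change): because the
// destructor of ImmediateExecutionTensorHandle is protected and the handle
// may manage its own lifetime through ref counting, clients destroy instances
// via Release() rather than `delete`, as documented above. One hypothetical
// way to make that automatic is a unique_ptr with a custom deleter:
//
//   struct Releaser {
//     void operator()(ImmediateExecutionTensorHandle* h) const {
//       if (h != nullptr) h->Release();
//     }
//   };
//   using ManagedHandle =
//       std::unique_ptr<ImmediateExecutionTensorHandle, Releaser>;
// ---------------------------------------------------------------------------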
static bool classof(const AbstractTensorHandle* ptr) { return ptr->getKind() == kEager || ptr->getKind() == kTfrt; diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 40cfa87dd66..27629bb3bdf 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -177,12 +177,12 @@ class GradientTape { template class ForwardFunction : public std::function&, - std::vector*)> { + std::vector*, bool)> { public: template explicit ForwardFunction(lambda_type lambda) : std::function&, - std::vector*)>(lambda) {} + std::vector*, bool)>(lambda) {} }; // Computes Jacobian-vector products using forward-mode automatic @@ -205,8 +205,9 @@ class ForwardAccumulator { // Does not take ownership of `vspace`, which must outlive the // ForwardAccumulator. explicit ForwardAccumulator( - const VSpace& vspace) - : vspace_(vspace) { + const VSpace& vspace, + bool use_batch) + : vspace_(vspace), use_batch_(use_batch) { call_state_.emplace(nullptr, false); } @@ -314,6 +315,9 @@ class ForwardAccumulator { // available in language bindings (e.g. Python). const VSpace& vspace_; + // Decides if tangents are vectorized or not + bool use_batch_; + struct AccumulatorCallState { AccumulatorCallState( GradientTape* backward_tape, @@ -573,7 +577,7 @@ Status InitialGradients( gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, const OpTape& op_tape, std::unordered_map>* result) { - for (int i = 0; i < target_tensor_ids.size(); ++i) { + for (int i = 0, end = target_tensor_ids.size(); i < end; ++i) { const int64 id = target_tensor_ids[i]; if (output_gradients.empty() || output_gradients[i] == nullptr) { auto tensor_it = tensor_tape.find(id); @@ -699,7 +703,7 @@ Status GradientTape::ComputeGradient( std::vector out_gradients; out_gradients.reserve(trace.output_tensor_info.size()); std::vector unneeded_gradients; - for (int i = 0; i < trace.input_tensor_id.size(); i++) { + for (int i = 0, end = trace.input_tensor_id.size(); i < end; i++) { const auto& in_tensor_id = trace.input_tensor_id[i]; if (tensor_tape_.find(in_tensor_id) == tensor_tape_.end() && sources_set.find(in_tensor_id) == sources_set.end()) { @@ -709,7 +713,7 @@ Status GradientTape::ComputeGradient( bool any_gradient_nonzero = false; std::vector zero_indices; - for (int i = 0; i < trace.output_tensor_info.size(); ++i) { + for (int i = 0, end = trace.output_tensor_info.size(); i < end; ++i) { const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { @@ -775,7 +779,7 @@ Status GradientTape::ComputeGradient( } VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " << trace.input_tensor_id.size() << " sources"; - for (int i = 0; i < in_gradients.size(); ++i) { + for (int i = 0, end = in_gradients.size(); i < end; ++i) { const int64 id = trace.input_tensor_id[i]; if (in_gradients[i] != nullptr) { auto& unaggregated_grads = gradients[id]; @@ -968,7 +972,7 @@ ForwardAccumulator::ForwardpropFromTape( targets.reserve(grad.size()); used_in_grads.reserve(grad.size()); std::unordered_map sources_that_are_targets; - for (int grad_index = 0; grad_index < grad.size(); ++grad_index) { + for (int grad_index = 0, end = grad.size(); grad_index < end; ++grad_index) { Gradient* grad_tensor = grad[grad_index]; if (grad_tensor != nullptr) { int64 tensor_id = vspace_.TensorId(grad_tensor); @@ -1062,7 +1066,8 @@ Status ForwardAccumulator::Accumulate( output_tensors, backward_function_getter, backward_function_deleter, in_grads, &forward_grads)); } else { - 
TF_RETURN_IF_ERROR((*forward_function)(in_grads, &forward_grads)); + TF_RETURN_IF_ERROR( + (*forward_function)(in_grads, &forward_grads, use_batch_)); } for (int i = 0; i < forward_grads.size(); ++i) { if (forward_grads[i] != nullptr) { diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index ce715c43acb..fbde13dea5a 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -186,3 +186,22 @@ void TF_JoinThread(TF_Thread* thread) { // ::tensorflow::Thread joins on destruction delete reinterpret_cast<::tensorflow::Thread*>(thread); } + +void* TF_LoadSharedLibrary(const char* library_filename, TF_Status* status) { + void* handle = nullptr; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->LoadDynamicLibrary(library_filename, + &handle)); + return handle; +} + +void* TF_GetSymbolFromLibrary(void* handle, const char* symbol_name, + TF_Status* status) { + void* symbol = nullptr; + TF_SetStatus(status, TF_OK, ""); + ::tensorflow::Set_TF_Status_from_Status( + status, ::tensorflow::Env::Default()->GetSymbolFromLibrary( + handle, symbol_name, &symbol)); + return symbol; +} diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h index 7dc7ac32f08..63e2c86ad44 100644 --- a/tensorflow/c/env.h +++ b/tensorflow/c/env.h @@ -184,6 +184,26 @@ TF_CAPI_EXPORT extern TF_Thread* TF_StartThread(const TF_ThreadOptions* options, // Waits for the given thread to finish execution, then deletes it. TF_CAPI_EXPORT extern void TF_JoinThread(TF_Thread* thread); +// \brief Load a dynamic library. +// +// Pass "library_filename" to a platform-specific mechanism for dynamically +// loading a library. The rules for determining the exact location of the +// library are platform-specific and are not documented here. +// +// On success, place OK in status and return the newly created library handle. +// Otherwise returns nullptr and set error status. +TF_CAPI_EXPORT extern void* TF_LoadSharedLibrary(const char* library_filename, + TF_Status* status); + +// \brief Get a pointer to a symbol from a dynamic library. +// +// "handle" should be a pointer returned from a previous call to +// TF_LoadLibraryFromEnv. On success, place OK in status and return a pointer to +// the located symbol. Otherwise returns nullptr and set error status. +TF_CAPI_EXPORT extern void* TF_GetSymbolFromLibrary(void* handle, + const char* symbol_name, + TF_Status* status); + #ifdef __cplusplus } #endif diff --git a/tensorflow/c/experimental/BUILD b/tensorflow/c/experimental/BUILD deleted file mode 100644 index 53cd99f18a6..00000000000 --- a/tensorflow/c/experimental/BUILD +++ /dev/null @@ -1,124 +0,0 @@ -# Description: -# Experimental C APIs for TensorFlow. 
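// ---------------------------------------------------------------------------
// Illustrative sketch (editorial, not part of this change): typical use of the
// TF_LoadSharedLibrary / TF_GetSymbolFromLibrary pair added to
// tensorflow/c/env.h above. The library name is illustrative; TF_InitPlugin is
// the entry point that modular filesystem plugins are expected to export.
TF_Status* status = TF_NewStatus();
void* lib = TF_LoadSharedLibrary("libexample_plugin.so", status);
if (TF_GetCode(status) == TF_OK) {
  void* init_fn = TF_GetSymbolFromLibrary(lib, "TF_InitPlugin", status);
  if (TF_GetCode(status) == TF_OK) {
    // Cast init_fn to the plugin's expected signature and invoke it here.
  }
}
TF_DeleteStatus(status);
// ---------------------------------------------------------------------------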
- -load( - "//tensorflow:tensorflow.bzl", - "tf_copts", - "tf_cuda_library", -) -load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") - -package( - licenses = ["notice"], # Apache 2.0 -) - -tf_cuda_library( - name = "rendezvous_internal", - srcs = [ - "rendezvous.cc", - ], - hdrs = [ - "rendezvous.h", - "rendezvous_internal.h", - ], - copts = tf_copts(), - visibility = ["//tensorflow/c:__subpackages__"], - deps = [ - "//tensorflow/c:c_api_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - ], -) - -tf_cuda_library( - name = "rendezvous", - hdrs = [ - "rendezvous.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":rendezvous_internal", - "//tensorflow/c:c_api", - ], -) - -tf_cuda_library( - name = "network_internal", - srcs = [ - "network.cc", - ], - hdrs = [ - "network.h", - "network_internal.h", - ], - copts = tf_copts(), - visibility = ["//tensorflow/c:__subpackages__"], - deps = [ - ":rendezvous_internal", - "//tensorflow/c:c_api_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - ], -) - -tf_cuda_library( - name = "network", - hdrs = [ - "network.h", - ], - copts = tf_copts(), - visibility = ["//visibility:public"], - deps = [ - ":network_internal", - ":rendezvous", - "//tensorflow/c:c_api", - ], -) - -# ----------------------------------------------------------------------------- -# Tests - -tf_cuda_cc_test( - name = "network_test", - size = "medium", - srcs = ["network_test.cc"], - tags = ["noasan"], - # We must ensure that the dependencies can be dynamically linked since - # the shared library must be able to use core:framework. 
- # linkstatic = tf_kernel_tests_linkstatic(), - deps = [ - ":network", - ":network_internal", - ":rendezvous", - ":rendezvous_internal", - "//tensorflow/c:c_api", - "//tensorflow/c:env", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", - "//tensorflow/core/distributed_runtime:server_lib", - "//tensorflow/core/distributed_runtime:session_mgr", - "//tensorflow/core/distributed_runtime:worker_env", - "//tensorflow/core/distributed_runtime:worker_session", - "//tensorflow/core/distributed_runtime/rpc:async_service_interface", - "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - ], -) diff --git a/tensorflow/c/experimental/filesystem/filesystem_interface.h b/tensorflow/c/experimental/filesystem/filesystem_interface.h index 5463eb35088..6e05c861439 100644 --- a/tensorflow/c/experimental/filesystem/filesystem_interface.h +++ b/tensorflow/c/experimental/filesystem/filesystem_interface.h @@ -78,6 +78,11 @@ typedef struct TF_Filesystem { void* plugin_filesystem; } TF_Filesystem; +typedef struct TF_TransactionToken { + void* token; + TF_Filesystem* owner; +} TF_TransactionToken; + /// SECTION 2. Function tables for functionality provided by plugins /// ---------------------------------------------------------------------------- /// @@ -679,6 +684,133 @@ typedef struct TF_FilesystemOps { /// /// DEFAULT IMPLEMENTATION: No op. void (*flush_caches)(const TF_Filesystem* filesystem); + + /// Starts a new transaction. + /// + /// An opaque transaction token is returned in `token`. Ownership of the token + /// is in filesystem. Token will be freed in `end_transaction` call and any + /// access to token after that is invalid. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if transaction successfuly started. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if multiple transactions + /// are not supported + /// * Might use any other error value for `status` to signal other errors. + int (*start_transaction)(const TF_Filesystem* filesystem, + TF_TransactionToken** token, TF_Status* status); + + /// Ends transaction and free the `token`. Any access to token after + /// that will be invalid. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if transaction successfuly finalized. + /// * Must set `status` to `TF_NOT_FOUND` if token is invalid/not found + /// * Might use any other error value for `status` to signal other errors. 
+ int (*end_transaction)(const TF_Filesystem* filesystem, + TF_TransactionToken* token, TF_Status* status); + + /// Adds file/directory in the `path` to transaction in `token`. It is a valid + /// operation to add a path that doesn't exist yet to a transaction. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if path added to transaction successful. + /// * Must set `status` to `TF_NOT_FOUND` if `token` is invalid. + /// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is in + /// another transaction and multiple transactions are not supported + /// * Might use any other error value for `status` to signal other errors. + int (*add_to_transaction)(const TF_Filesystem* filesystem, const char* path, + TF_TransactionToken* token, TF_Status* status); + + /// Returns transaction token for file/directory in the `path`. Note that path + /// may not exist yet but still might be part of a transaction. + /// + /// Transaction token is returned in `token`. Ownership of the token is in + /// filesystem. Token will be freed in `end_transaction` call and any access + /// to token after that is invalid. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if a transaction for path is found + /// * Must set `status` to `TF_NOT_FOUND` if `path` is not part of any + /// transaction + /// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is + /// not in this filesystem. + /// * Might use any other error value for `status` to signal other errors. + int (*get_transaction_for_path)(const TF_Filesystem* filesystem, + const char* path, TF_TransactionToken** token, + TF_Status* status); + + /// Returns transaction token for `path` if it is part of a transaction else + /// starts a new transaction and adds `path` to that transaction + /// + /// Transaction token is returned in `token`. Ownership of the token is in + /// filesystem. Token will be freed in `end_transaction` call and any access + /// to token after that is invalid. + /// + /// In case of error, plugins must set `status` to a value different than + /// `TF_OK`, free memory allocated for `token` and return -1. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// Plugins: + /// * Must set `status` to `TF_OK` if transaction found or successfuly + /// started. + /// * Must set `status` to `TF_NOT_FOUND` if `path` doesn't point to this + /// filesystem + /// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is + /// not in any transaction and multiple transactions are not supported. + /// * Might use any other error value for `status` to signal other errors. 
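// ---------------------------------------------------------------------------
// Illustrative sketch (editorial, not part of this change): a minimal
// plugin-side implementation of the `start_transaction` / `end_transaction`
// hooks documented in this table. `plugin_memory_allocate` and
// `plugin_memory_free` stand in for the allocator callbacks a plugin receives
// at registration (see Section 4); all names here are illustrative.
//
//   static int my_start_transaction(const TF_Filesystem* filesystem,
//                                   TF_TransactionToken** token,
//                                   TF_Status* status) {
//     *token = (TF_TransactionToken*)plugin_memory_allocate(
//         sizeof(TF_TransactionToken));
//     (*token)->token = nullptr;  // No per-transaction state in this sketch.
//     (*token)->owner = (TF_Filesystem*)filesystem;
//     TF_SetStatus(status, TF_OK, "");
//     return 0;
//   }
//
//   static int my_end_transaction(const TF_Filesystem* filesystem,
//                                 TF_TransactionToken* token,
//                                 TF_Status* status) {
//     plugin_memory_free(token);  // The token is invalid after this call.
//     TF_SetStatus(status, TF_OK, "");
//     return 0;
//   }
// ---------------------------------------------------------------------------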
+ int (*get_or_start_transaction_for_path)(const TF_Filesystem* filesystem, + const char* path, + TF_TransactionToken** token, + TF_Status* status); + + /// Decodes transaction token in `token` to human readable format for + /// debugging. + /// + /// A new `char*` buffer must be allocated by this method. Core TensorFlow + /// manages the lifetime of the buffer after the call. Thus, all callers of + /// this method must take ownership of the returned pointer. + /// + /// Plugins must not return `nullptr`. Returning empty strings is allowed. + /// + /// The allocation and freeing of memory must happen via the functions sent to + /// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo` + /// structure in Section 4). + /// + /// DEFAULT IMPLEMENTATION: Dump token and owner address. + char* (*decode_transaction_token)(const TF_Filesystem* filesystem, + const TF_TransactionToken* token); + } TF_FilesystemOps; // LINT.ThenChange(:filesystem_ops_version) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index 58541ea2b36..9c8d3518800 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -35,7 +35,8 @@ using UniquePtrTo_TF_Status = ::std::unique_ptr; Status ModularFileSystem::NewRandomAccessFile( - const std::string& fname, std::unique_ptr* result) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_random_access_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewRandomAccessFile()")); @@ -54,7 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile( } Status ModularFileSystem::NewWritableFile( - const std::string& fname, std::unique_ptr* result) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_writable_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewWritableFile()")); @@ -73,7 +75,8 @@ Status ModularFileSystem::NewWritableFile( } Status ModularFileSystem::NewAppendableFile( - const std::string& fname, std::unique_ptr* result) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_appendable_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support NewAppendableFile()")); @@ -92,7 +95,8 @@ Status ModularFileSystem::NewAppendableFile( } Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile( - const std::string& fname, std::unique_ptr* result) { + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) { if (ops_->new_read_only_memory_region_from_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, @@ -112,7 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile( return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::FileExists(const std::string& fname) { +Status ModularFileSystem::FileExists(const std::string& fname, + TransactionToken* token) { if (ops_->path_exists == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support FileExists()")); @@ -125,6 +130,7 @@ Status ModularFileSystem::FileExists(const std::string& fname) { } bool ModularFileSystem::FilesExist(const std::vector& files, + TransactionToken* token, 
std::vector* status) { if (ops_->paths_exist == nullptr) return FileSystem::FilesExist(files, status); @@ -157,6 +163,7 @@ bool ModularFileSystem::FilesExist(const std::vector& files, } Status ModularFileSystem::GetChildren(const std::string& dir, + TransactionToken* token, std::vector* result) { if (ops_->get_children == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( @@ -182,6 +189,7 @@ Status ModularFileSystem::GetChildren(const std::string& dir, } Status ModularFileSystem::GetMatchingPaths(const std::string& pattern, + TransactionToken* token, std::vector* result) { if (ops_->get_matching_paths == nullptr) return internal::GetMatchingPaths(this, Env::Default(), pattern, result); @@ -203,7 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern, return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::DeleteFile(const std::string& fname) { +Status ModularFileSystem::DeleteFile(const std::string& fname, + TransactionToken* token) { if (ops_->delete_file == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support DeleteFile()")); @@ -216,6 +225,7 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) { } Status ModularFileSystem::DeleteRecursively(const std::string& dirname, + TransactionToken* token, int64* undeleted_files, int64* undeleted_dirs) { if (undeleted_files == nullptr || undeleted_dirs == nullptr) @@ -238,7 +248,8 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname, return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::DeleteDir(const std::string& dirname) { +Status ModularFileSystem::DeleteDir(const std::string& dirname, + TransactionToken* token) { if (ops_->delete_dir == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", dirname, " does not support DeleteDir()")); @@ -250,7 +261,8 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) { return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) { +Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) { if (ops_->recursively_create_dir == nullptr) return FileSystem::RecursivelyCreateDir(dirname); @@ -261,7 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) { return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::CreateDir(const std::string& dirname) { +Status ModularFileSystem::CreateDir(const std::string& dirname, + TransactionToken* token) { if (ops_->create_dir == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", dirname, " does not support CreateDir()")); @@ -273,7 +286,8 @@ Status ModularFileSystem::CreateDir(const std::string& dirname) { return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) { +Status ModularFileSystem::Stat(const std::string& fname, + TransactionToken* token, FileStatistics* stat) { if (ops_->stat == nullptr) return errors::Unimplemented(tensorflow::strings::StrCat( "Filesystem for ", fname, " does not support Stat()")); @@ -296,7 +310,8 @@ Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) { return StatusFromTF_Status(plugin_status.get()); } -Status ModularFileSystem::IsDirectory(const std::string& name) { +Status ModularFileSystem::IsDirectory(const std::string& name, + 
TransactionToken* token) { if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); @@ -307,6 +322,7 @@ Status ModularFileSystem::IsDirectory(const std::string& name) { } Status ModularFileSystem::GetFileSize(const std::string& fname, + TransactionToken* token, uint64* file_size) { if (ops_->get_file_size == nullptr) { FileStatistics stat; @@ -327,7 +343,8 @@ Status ModularFileSystem::GetFileSize(const std::string& fname, } Status ModularFileSystem::RenameFile(const std::string& src, - const std::string& target) { + const std::string& target, + TransactionToken* token) { if (ops_->rename_file == nullptr) { Status status = CopyFile(src, target); if (status.ok()) status = DeleteFile(src); @@ -343,7 +360,8 @@ Status ModularFileSystem::RenameFile(const std::string& src, } Status ModularFileSystem::CopyFile(const std::string& src, - const std::string& target) { + const std::string& target, + TransactionToken* token) { if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target); UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus); @@ -366,7 +384,7 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const { return ret; } -void ModularFileSystem::FlushCaches() { +void ModularFileSystem::FlushCaches(TransactionToken* token) { if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get()); } @@ -443,7 +461,7 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) { // Step 1: Load plugin Env* env = Env::Default(); void* dso_handle; - TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle)); + TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle)); // Step 2: Load symbol for `TF_InitPlugin` void* dso_symbol; diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.h b/tensorflow/c/experimental/filesystem/modular_filesystem.h index baf665fd6aa..061a1aa446b 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -59,36 +59,48 @@ class ModularFileSystem final : public FileSystem { ~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); } + TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; + Status NewRandomAccessFile( - const std::string& fname, + const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status NewWritableFile(const std::string& fname, + Status NewWritableFile(const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status NewAppendableFile(const std::string& fname, + Status NewAppendableFile(const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; Status NewReadOnlyMemoryRegionFromFile( - const std::string& fname, + const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status FileExists(const std::string& fname) override; + Status FileExists(const std::string& fname, TransactionToken* token) override; bool FilesExist(const std::vector& files, + TransactionToken* token, std::vector* status) override; - Status GetChildren(const std::string& dir, + Status GetChildren(const std::string& dir, TransactionToken* token, std::vector* result) override; - Status GetMatchingPaths(const std::string& pattern, + Status GetMatchingPaths(const std::string& pattern, TransactionToken* token, std::vector* results) override; - Status DeleteFile(const std::string& fname) override; - Status 
DeleteRecursively(const std::string& dirname, int64* undeleted_files, + Status DeleteFile(const std::string& fname, TransactionToken* token) override; + Status DeleteRecursively(const std::string& dirname, TransactionToken* token, + int64* undeleted_files, int64* undeleted_dirs) override; - Status DeleteDir(const std::string& dirname) override; - Status RecursivelyCreateDir(const std::string& dirname) override; - Status CreateDir(const std::string& dirname) override; - Status Stat(const std::string& fname, FileStatistics* stat) override; - Status IsDirectory(const std::string& fname) override; - Status GetFileSize(const std::string& fname, uint64* file_size) override; - Status RenameFile(const std::string& src, const std::string& target) override; - Status CopyFile(const std::string& src, const std::string& target) override; + Status DeleteDir(const std::string& dirname, + TransactionToken* token) override; + Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) override; + Status CreateDir(const std::string& dirname, + TransactionToken* token) override; + Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) override; + Status IsDirectory(const std::string& fname, + TransactionToken* token) override; + Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) override; + Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) override; + Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) override; std::string TranslateName(const std::string& name) const override; - void FlushCaches() override; + void FlushCaches(TransactionToken* token) override; private: std::unique_ptr filesystem_; diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc index 8ee47da01dd..7e0a95cc915 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_test.cc @@ -33,7 +33,6 @@ limitations under the License. // Windows defines the following macros to convert foo to fooA or fooW, // depending on the type of the string argument. We don't use these macros, so // undefine them here. 
-#undef LoadLibrary #undef CopyFile #undef DeleteFile #undef TranslateName diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index a0c13701766..68875d61e47 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -25,12 +25,15 @@ cc_library( "//tensorflow:windows": get_win_copts(), }), deps = [ + ":expiring_lru_cache", ":gcs_helper", + ":ram_file_block_cache", "//tensorflow/c:env", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", ], ) @@ -44,14 +47,6 @@ cc_library( ], ) -cc_library( - name = "file_block_cache", - hdrs = ["file_block_cache.h"], - deps = [ - "//tensorflow/c:tf_status", - ], -) - cc_library( name = "cleanup", hdrs = ["cleanup.h"], @@ -63,7 +58,6 @@ cc_library( hdrs = ["ram_file_block_cache.h"], deps = [ ":cleanup", - ":file_block_cache", "//tensorflow/c:env", "//tensorflow/c:tf_status", "@com_google_absl//absl/base:core_headers", diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h deleted file mode 100644 index 3ba7d8d7993..00000000000 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_ -#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/c/tf_status.h" - -namespace tf_gcs_filesystem { - -class FileBlockCache; - -/// FileBlockCacheStatsInterface allows for instrumentation of the block cache. -/// -/// FileBlockCacheStatsInterface and its subclasses must be safe to use from -/// multiple threads concurrently. -/// -/// WARNING! This is an experimental interface that may change or go away at any -/// time. -class FileBlockCacheStatsInterface { - public: - /// Configure is called to provide instrumentation hooks. - /// - /// Note: Configure can be called multiple times (e.g. if the block cache is - /// re-initialized). - virtual void Configure(const FileBlockCache* block_cache) = 0; - - /// RecordBlockLoadRequest is called to record the size of a hit block. - virtual void RecordCacheHitBlockSize(size_t bytes_transferred) = 0; - - /// RecordBlockLoadRequest is called to record the size of a missed block. - virtual void RecordCacheMissBlockSize(size_t bytes_transferred) = 0; - - virtual ~FileBlockCacheStatsInterface() = default; -}; - -/// \brief A block cache of file contents, keyed by {filename, offset}. 
-/// -/// This class should be shared by read-only random access files on a remote -/// filesystem (e.g. GCS). -class FileBlockCache { - public: - /// The callback executed when a block is not found in the cache, and needs to - /// be fetched from the backing filesystem. This callback is provided when the - /// cache is constructed. The `status` should be `TF_OK` as long as the - /// read from the remote filesystem succeeded (similar to the semantics of the - /// read(2) system call). - typedef std::function - BlockFetcher; - - virtual ~FileBlockCache() {} - - /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This - /// method will set `status` to: - /// - /// 1) The error from the remote filesystem, if the read from the remote - /// filesystem failed. - /// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem - /// succeeded, - /// but the read returned a partial block, and the LRU cache contained a - /// block at a higher offset (indicating that the partial block should have - /// been a full block). - /// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but - /// the file contents do not extend past `offset` and thus nothing was - /// placed in `out`. - /// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was - /// placed - /// in `buffer`). - /// - /// Caller is responsible for allocating memory for `buffer`. - /// `buffer` will be left unchanged in case of errors. - virtual void Read(const std::string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) = 0; - - // Validate the given file signature with the existing file signature in the - // cache. Returns true if the signature doesn't change or the file did not - // exist before. If the signature changes, update the existing signature with - // the new one and remove the file from cache. - virtual bool ValidateAndUpdateFileSignature(const std::string& filename, - int64_t file_signature) = 0; - - /// Remove all cached blocks for `filename`. - virtual void RemoveFile(const std::string& filename) = 0; - - /// Remove all cached data. - virtual void Flush() = 0; - - /// Accessors for cache parameters. - virtual size_t block_size() const = 0; - virtual size_t max_bytes() const = 0; - virtual uint64_t max_staleness() const = 0; - - /// The current size (in bytes) of the cache. - virtual size_t CacheSize() const = 0; - - // Returns true if the cache is enabled. If false, the BlockFetcher callback - // is always executed during Read. - virtual bool IsCacheEnabled() const = 0; - - void SetStats(FileBlockCacheStatsInterface* stats) { - if (stats == nullptr) { - std::cerr - << "Attempted to monitor a NULL stats object. This may prevent the " - "corresponding monitoring data from being exported"; - return; - } - cache_stats_ = stats; - cache_stats_->Configure(this); - } - - protected: - FileBlockCacheStatsInterface* cache_stats_ = nullptr; // Not owned. -}; - -} // namespace tf_gcs_filesystem - -#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index 7861a5708b5..e01af918100 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/types/variant.h" #include "google/cloud/storage/client.h" #include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" @@ -27,6 +29,27 @@ limitations under the License. // This filesystem will support `gs://` URI schemes. namespace gcs = google::cloud::storage; +// The environment variable that overrides the block size for aligned reads from +// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes). +constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB"; +constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024; +// The environment variable that overrides the max size of the LRU cache of +// blocks read from GCS. Specified in MB. +constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB"; +constexpr size_t kDefaultMaxCacheSize = 0; +// The environment variable that overrides the maximum staleness of cached file +// contents. Once any block of a file reaches this staleness, all cached blocks +// will be evicted on the next read. +constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS"; +constexpr uint64_t kDefaultMaxStaleness = 0; + +constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE"; +constexpr uint64_t kStatCacheDefaultMaxAge = 5; +// The environment variable that overrides the maximum number of entries in the +// Stat cache. +constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES"; +constexpr size_t kStatCacheDefaultMaxEntries = 1024; + // How to upload new data when Flush() is called multiple times. // By default the entire file is reuploaded. constexpr char kAppendMode[] = "GCS_APPEND_MODE"; @@ -81,28 +104,16 @@ static void MaybeAppendSlash(std::string* name) { name->push_back('/'); } -// SECTION 1. Implementation for `TF_RandomAccessFile` -// ---------------------------------------------------------------------------- -namespace tf_random_access_file { -typedef struct GCSFile { - const std::string bucket; - const std::string object; - gcs::Client* gcs_client; // not owned -} GCSFile; - -void Cleanup(TF_RandomAccessFile* file) { - auto gcs_file = static_cast(file->plugin_file); - delete gcs_file; -} - -// TODO(vnvo2409): Adding cache. -// `google-cloud-cpp` is working on a feature that we may want to use. -// See https://github.com/googleapis/google-cloud-cpp/issues/4013. -int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, - char* buffer, TF_Status* status) { - auto gcs_file = static_cast(file->plugin_file); - auto stream = gcs_file->gcs_client->ReadObject( - gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n)); +// A helper function to actually read the data from GCS. 
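// ---------------------------------------------------------------------------
// Illustrative example (editorial, not part of this change): the cache knobs
// above are read from the environment when the filesystem is initialized, so
// the embedding process can set them beforehand, for example with POSIX
// setenv. Values are illustrative; sizes are in MB and staleness in seconds,
// and with the default max cache size of 0 the block cache is effectively
// left disabled.
setenv("GCS_READ_CACHE_BLOCK_SIZE_MB", "16", /*overwrite=*/1);   // 16 MB blocks
setenv("GCS_READ_CACHE_MAX_SIZE_MB", "256", /*overwrite=*/1);    // 256 MB LRU cache
setenv("GCS_READ_CACHE_MAX_STALENESS", "300", /*overwrite=*/1);  // evict after 5 minutes
// ---------------------------------------------------------------------------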
+static int64_t LoadBufferFromGCS(const std::string& path, size_t offset, + size_t buffer_size, char* buffer, + tf_gcs_filesystem::GCSFile* gcs_file, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return -1; + auto stream = gcs_file->gcs_client.ReadObject( + bucket, object, gcs::ReadRange(offset, offset + buffer_size)); TF_SetStatusFromGCSStatus(stream.status(), status); if ((TF_GetCode(status) != TF_OK) && (TF_GetCode(status) != TF_OUT_OF_RANGE)) { @@ -111,16 +122,119 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, int64_t read; if (!absl::SimpleAtoi(stream.headers().find("content-length")->second, &read)) { - TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); - return -1; - } - if (read != n) { - TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + // When we read a file with offset that is bigger than the actual file size. + // GCS will return an empty header (e.g no `content-length` header). In this + // case, we will set read to `0` and continue. + if (TF_GetCode(status) == TF_OUT_OF_RANGE) { + read = 0; + } else { + TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header"); + return -1; + } } + // `TF_OUT_OF_RANGE` isn't considered as an error. So we clear it here. + TF_SetStatus(status, TF_OK, ""); stream.read(buffer, read); + read = stream.gcount(); + if (read < buffer_size) { + // Check stat cache to see if we encountered an interrupted read. + tf_gcs_filesystem::GcsFileStat stat; + if (gcs_file->stat_cache->Lookup(path, &stat)) { + if (offset + read < stat.base.length) { + TF_SetStatus(status, TF_INTERNAL, + absl::StrCat("File contents are inconsistent for file: ", + path, " @ ", offset) + .c_str()); + } + } + } return read; } +// SECTION 1. Implementation for `TF_RandomAccessFile` +// ---------------------------------------------------------------------------- +namespace tf_random_access_file { +using ReadFn = + std::function; +typedef struct GCSFile { + const std::string path; + const bool is_cache_enable; + const uint64_t buffer_size; + ReadFn read_fn; + absl::Mutex buffer_mutex; + uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex); + bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex); + std::string buffer ABSL_GUARDED_BY(buffer_mutex); + + GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size, + ReadFn read_fn) + : path(path), + is_cache_enable(is_cache_enable), + buffer_size(buffer_size), + read_fn(std::move(read_fn)), + buffer_mutex(), + buffer_start(0), + buffer_end_is_past_eof(false), + buffer() {} +} GCSFile; + +void Cleanup(TF_RandomAccessFile* file) { + auto gcs_file = static_cast(file->plugin_file); + delete gcs_file; +} + +// `google-cloud-cpp` is working on a feature that we may want to use. +// See https://github.com/googleapis/google-cloud-cpp/issues/4013. 
+int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status) { + auto gcs_file = static_cast(file->plugin_file); + if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) { + return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status); + } else { + absl::MutexLock l(&gcs_file->buffer_mutex); + size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size(); + size_t copy_size = 0; + if (offset < buffer_end && gcs_file->buffer_start) { + copy_size = (std::min)(n, static_cast(buffer_end - offset)); + memcpy(buffer, + gcs_file->buffer.data() + (offset - gcs_file->buffer_start), + copy_size); + } + bool consumed_buffer_to_eof = + offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof; + if (copy_size < n && !consumed_buffer_to_eof) { + gcs_file->buffer_start = offset + copy_size; + gcs_file->buffer.resize(gcs_file->buffer_size); + auto read_fill_buffer = gcs_file->read_fn( + gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size, + &(gcs_file->buffer[0]), status); + gcs_file->buffer_end_is_past_eof = + (TF_GetCode(status) == TF_OUT_OF_RANGE); + if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer); + if (TF_GetCode(status) != TF_OK && + TF_GetCode(status) != TF_OUT_OF_RANGE) { + // Empty the buffer to avoid caching bad reads. + gcs_file->buffer.resize(0); + return -1; + } + size_t remaining_copy = + (std::min)(n - copy_size, gcs_file->buffer.size()); + memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy); + copy_size += remaining_copy; + } + if (copy_size < n) { + // Forget the end-of-file flag to allow for clients that poll on the + // same file. + gcs_file->buffer_end_is_past_eof = false; + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + return copy_size; + } + TF_SetStatus(status, TF_OK, ""); + return copy_size; + } +} + } // namespace tf_random_access_file // SECTION 2. Implementation for `TF_WritableFile` @@ -289,11 +403,87 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) { // SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem // ---------------------------------------------------------------------------- namespace tf_gcs_filesystem { -// TODO(vnvo2409): Add lazy-loading and customizing parameters. // TODO(vnvo2409): Use partial reponse for better performance. // TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`. // TODO(vnvo2409): Refactor the filesystem implementation when // https://github.com/googleapis/google-cloud-cpp/issues/4482 is done. +GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client) + : gcs_client(gcs_client), block_cache_lock() { + const char* append_mode = std::getenv(kAppendMode); + compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode)); + + uint64_t value; + block_size = kDefaultBlockSize; + size_t max_bytes = kDefaultMaxCacheSize; + uint64_t max_staleness = kDefaultMaxStaleness; + + // Apply the overrides for the block size (MB), max bytes (MB), and max + // staleness (seconds) if provided. 
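// ---------------------------------------------------------------------------
// Editorial note (not part of this change): std::getenv returns nullptr when
// a variable is unset, and the absl::SimpleAtoi calls below take that result
// directly. A guarded helper along these lines (hypothetical) keeps the
// defaults above whenever the variable is absent or unparsable:
//   static bool GetEnvUint64(const char* var, uint64_t* value) {
//     const char* env_value = std::getenv(var);
//     return env_value != nullptr && absl::SimpleAtoi(env_value, value);
//   }
// ---------------------------------------------------------------------------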
+ if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) { + block_size = value * 1024 * 1024; + } + if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) { + max_bytes = static_cast(value * 1024 * 1024); + } + if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) { + max_staleness = value; + } + + file_block_cache = std::make_unique( + block_size, max_bytes, max_staleness, + [this](const std::string& filename, size_t offset, size_t buffer_size, + char* buffer, TF_Status* status) { + return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this, + status); + }); + + uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge; + size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries; + if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) { + stat_cache_max_age = value; + } + if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) { + stat_cache_max_entries = static_cast(value); + } + stat_cache = std::make_unique>( + stat_cache_max_age, stat_cache_max_entries); +} + +GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client, bool compose, + uint64_t block_size, size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries) + : gcs_client(gcs_client), + compose(compose), + block_cache_lock(), + block_size(block_size) { + file_block_cache = std::make_unique( + block_size, max_bytes, max_staleness, + [this](const std::string& filename, size_t offset, size_t buffer_size, + char* buffer, TF_Status* status) { + return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this, + status); + }); + stat_cache = std::make_unique>( + stat_cache_max_age, stat_cache_max_entries); +} + +void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size, + size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries, + TF_Status* status) { + google::cloud::StatusOr client = + gcs::Client::CreateDefaultClient(); + if (!client) { + TF_SetStatusFromGCSStatus(client.status(), status); + return; + } + + filesystem->plugin_filesystem = + new GCSFile(std::move(client.value()), compose, block_size, max_bytes, + max_staleness, stat_cache_max_age, stat_cache_max_entries); + TF_SetStatus(status, TF_OK, ""); +} + void Init(TF_Filesystem* filesystem, TF_Status* status) { google::cloud::StatusOr client = gcs::Client::CreateDefaultClient(); @@ -302,12 +492,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) { return; } - const char* append_mode = std::getenv(kAppendMode); - bool compose = - (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode)); - - filesystem->plugin_filesystem = - new GCSFile({std::move(client.value()), compose}); + filesystem->plugin_filesystem = new GCSFile(std::move(client.value())); TF_SetStatus(status, TF_OK, ""); } @@ -316,6 +501,19 @@ void Cleanup(TF_Filesystem* filesystem) { delete gcs_file; } +static void UncachedStatForObject(const std::string& bucket, + const std::string& object, GcsFileStat* stat, + gcs::Client* gcs_client, TF_Status* status) { + auto metadata = gcs_client->GetObjectMetadata(bucket, object); + if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status); + stat->generation_number = metadata->generation(); + stat->base.length = metadata->size(); + stat->base.mtime_nsec = + metadata->time_storage_class_updated().time_since_epoch().count(); + stat->base.is_directory = object.back() == '/'; + return TF_SetStatus(status, TF_OK, ""); +} + // TODO(vnvo2409): Implement later void NewRandomAccessFile(const TF_Filesystem* 
filesystem, const char* path, TF_RandomAccessFile* file, TF_Status* status) { @@ -324,8 +522,46 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, if (TF_GetCode(status) != TF_OK) return; auto gcs_file = static_cast(filesystem->plugin_filesystem); + bool is_cache_enabled; + { + absl::MutexLock l(&gcs_file->block_cache_lock); + is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled(); + } + auto read_fn = [gcs_file, is_cache_enabled, bucket, object]( + const std::string& path, uint64_t offset, size_t n, + char* buffer, TF_Status* status) -> int64_t { + int64_t read = 0; + if (is_cache_enabled) { + absl::ReaderMutexLock l(&gcs_file->block_cache_lock); + GcsFileStat stat; + gcs_file->stat_cache->LookupOrCompute( + path, &stat, + [gcs_file, bucket, object](const std::string& path, GcsFileStat* stat, + TF_Status* status) { + UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client, + status); + }, + status); + if (TF_GetCode(status) != TF_OK) return -1; + if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature( + path, stat.generation_number)) { + std::cout + << "File signature has been changed. Refreshing the cache. Path: " + << path; + } + read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status); + } else { + read = LoadBufferFromGCS(path, offset, n, buffer, gcs_file, status); + } + if (TF_GetCode(status) != TF_OK) return -1; + if (read < n) + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + else + TF_SetStatus(status, TF_OK, ""); + return read; + }; file->plugin_file = new tf_random_access_file::GCSFile( - {std::move(bucket), std::move(object), &gcs_file->gcs_client}); + std::move(path), is_cache_enabled, gcs_file->block_size, read_fn); TF_SetStatus(status, TF_OK, ""); } @@ -428,28 +664,179 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, } } -void CreateDir(const TF_Filesystem* filesystem, const char* path, - TF_Status* status) { +static void StatForObject(GCSFile* gcs_file, const std::string& path, + const std::string& bucket, const std::string& object, + GcsFileStat* stat, TF_Status* status) { + if (object.empty()) + return TF_SetStatus( + status, TF_INVALID_ARGUMENT, + ("'object' must be a non-empty string. 
(File: " + path + ")").c_str()); + TF_SetStatus(status, TF_OK, ""); + gcs_file->stat_cache->LookupOrCompute( + path, stat, + [gcs_file, bucket, object](const std::string& path, GcsFileStat* stat, + TF_Status* status) { + UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client, + status); + }, + status); +} + +static bool ObjectExists(GCSFile* gcs_file, const std::string& path, + const std::string& bucket, const std::string& object, + TF_Status* status) { + GcsFileStat stat; + StatForObject(gcs_file, path, bucket, object, &stat, status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND) + return false; + if (TF_GetCode(status) == TF_NOT_FOUND) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return !stat.base.is_directory; +} + +static bool BucketExists(GCSFile* gcs_file, const std::string& bucket, + TF_Status* status) { + auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND) + return false; + if (TF_GetCode(status) == TF_NOT_FOUND) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return true; +} + +static std::vector GetChildrenBounded( + GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive, + bool include_self_directory_marker, TF_Status* status) { + std::string bucket, prefix; + MaybeAppendSlash(&dir); + ParseGCSPath(dir, true, &bucket, &prefix, status); + + std::vector result; + uint64_t count = 0; + std::string delimiter = recursive ? "" : "/"; + + for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes( + bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) { + if (count == max_results) { + TF_SetStatus(status, TF_OK, ""); + return result; + } + if (!item) { + TF_SetStatusFromGCSStatus(item.status(), status); + return result; + } + auto value = *std::move(item); + std::string children = absl::holds_alternative(value) + ? 
absl::get(value) + : absl::get(value).name(); + auto pos = children.find(prefix); + if (pos != 0) { + TF_SetStatus(status, TF_INTERNAL, + ("Unexpected response: the returned file name " + children + + " doesn't match the prefix " + prefix) + .c_str()); + return result; + } + children.erase(0, prefix.length()); + if (!children.empty() || include_self_directory_marker) { + result.emplace_back(children); + } + ++count; + } + + return result; +} + +static bool FolderExists(GCSFile* gcs_file, std::string dir, + TF_Status* status) { + ExpiringLRUCache::ComputeFunc compute_func = + [gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) { + auto children = + GetChildrenBounded(gcs_file, dir, 1, true, true, status); + if (TF_GetCode(status) != TF_OK) return; + if (!children.empty()) { + stat->base = {0, 0, true}; + return TF_SetStatus(status, TF_OK, ""); + } else { + return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!"); + } + }; + GcsFileStat stat; + MaybeAppendSlash(&dir); + gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT) + return false; + if (TF_GetCode(status) == TF_INVALID_ARGUMENT) { + TF_SetStatus(status, TF_OK, ""); + return false; + } + return true; +} + +static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) { + absl::ReaderMutexLock l(&gcs_file->block_cache_lock); + gcs_file->file_block_cache->RemoveFile(path); + gcs_file->stat_cache->Delete(path); +} + +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { std::string bucket, object; ParseGCSPath(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); if (object.empty()) { - auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); - TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); + bool result = BucketExists(gcs_file, bucket, status); + if (result) return TF_SetStatus(status, TF_OK, ""); + } + + GcsFileStat stat; + StatForObject(gcs_file, path, bucket, object, &stat, status); + if (TF_GetCode(status) != TF_NOT_FOUND) return; + + bool result = FolderExists(gcs_file, path, status); + if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result)) + return; + return TF_SetStatus( + status, TF_NOT_FOUND, + absl::StrCat("The path ", path, " does not exist.").c_str()); +} + +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string dir = path; + MaybeAppendSlash(&dir); + std::string bucket, object; + ParseGCSPath(dir, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto gcs_file = static_cast(filesystem->plugin_filesystem); + if (object.empty()) { + bool is_directory = BucketExists(gcs_file, bucket, status); + if (TF_GetCode(status) != TF_OK) return; + if (!is_directory) + TF_SetStatus(status, TF_NOT_FOUND, + ("The specified bucket " + dir + " was not found.").c_str()); return; } - MaybeAppendSlash(&object); - auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); - TF_SetStatusFromGCSStatus(object_metadata.status(), status); - if (TF_GetCode(status) == TF_NOT_FOUND) { - auto insert_metadata = - gcs_file->gcs_client.InsertObject(bucket, object, ""); - TF_SetStatusFromGCSStatus(insert_metadata.status(), status); - } else if (TF_GetCode(status) == TF_OK) { + PathExists(filesystem, dir.c_str(), status); + if (TF_GetCode(status) == TF_OK) 
+ return TF_SetStatus(status, TF_ALREADY_EXISTS, path); + + auto metadata = gcs_file->gcs_client.InsertObject( + bucket, object, "", + // Adding this parameter means HTTP_CODE_PRECONDITION_FAILED + // will be returned if the object already exists, so avoid reuploading. + gcs::IfGenerationMatch(0)); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) TF_SetStatus(status, TF_ALREADY_EXISTS, path); - } } // TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the @@ -465,79 +852,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, auto gcs_file = static_cast(filesystem->plugin_filesystem); auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); TF_SetStatusFromGCSStatus(gcs_status, status); + if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path); } +// Checks that the directory is empty (i.e no objects with this prefix exist). +// Deletes the GCS directory marker if it exists. void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { - std::string bucket, object; - ParseGCSPath(path, false, &bucket, &object, status); - if (TF_GetCode(status) != TF_OK) return; - MaybeAppendSlash(&object); + // A directory is considered empty either if there are no matching objects + // with the corresponding name prefix or if there is exactly one matching + // object and it is the directory marker. Therefore we need to retrieve + // at most two children for the prefix to detect if a directory is empty. auto gcs_file = static_cast(filesystem->plugin_filesystem); - int object_count = 0; - for (auto&& metadata : - gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) { - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); - return; - } - ++object_count; - // We consider a path is a non-empty directory in two cases: - // - There are more than two objects whose keys start with the name of this - // directory. - // - There is one object whose key contains the name of this directory ( but - // not equal ). - if (object_count > 1 || metadata->name() != object) { - TF_SetStatus(status, TF_FAILED_PRECONDITION, - "Cannot delete a non-empty directory."); - return; - } - } - auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object); - TF_SetStatusFromGCSStatus(gcs_status, status); -} - -// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be -// some differents compared to the default implementation. Will be refactored. -static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path, - uint64_t* undeleted_files, - uint64_t* undeleted_dirs, TF_Status* status) { - std::string bucket, object; - ParseGCSPath(path, false, &bucket, &object, status); + auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status); if (TF_GetCode(status) != TF_OK) return; - - auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object); - TF_SetStatusFromGCSStatus(gcs_status, status); - if (TF_GetCode(status) != TF_OK) return; - *undeleted_dirs = 0; - *undeleted_files = 0; -} - -// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND` -// if the object does not exist. In that case, we will have to check if the -// `src` is a directory or not to set the correspondent `status` (i.e -// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if -// path `src` is a directory). 
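Since a GCS "rename" has to be composed from a rewrite followed by a delete of the source (the pattern used by the new `RenameObject`/`RenameFile` pair further down), a condensed sketch of that object-level flow with the `google-cloud-cpp` client may help. Names are illustrative and cache invalidation is omitted, so treat this as an outline rather than the plugin's implementation.

namespace gcs = google::cloud::storage;

// Sketch: GCS has no atomic rename, so "rename" is rewrite + delete-source.
static void RenameObjectSketch(gcs::Client& client,
                               const std::string& src_bucket,
                               const std::string& src_object,
                               const std::string& dst_bucket,
                               const std::string& dst_object,
                               TF_Status* status) {
  auto metadata = client.RewriteObjectBlocking(src_bucket, src_object,
                                               dst_bucket, dst_object);
  TF_SetStatusFromGCSStatus(metadata.status(), status);
  if (TF_GetCode(status) != TF_OK) return;  // e.g. TF_NOT_FOUND if src is absent.
  auto delete_status = client.DeleteObject(src_bucket, src_object);
  TF_SetStatusFromGCSStatus(delete_status, status);
}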
-void RenameFile(const TF_Filesystem* filesystem, const char* src, - const char* dst, TF_Status* status) { - std::string bucket_src, object_src; - ParseGCSPath(src, false, &bucket_src, &object_src, status); - if (TF_GetCode(status) != TF_OK) return; - - std::string bucket_dst, object_dst; - ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); - if (TF_GetCode(status) != TF_OK) return; - - auto gcs_file = static_cast(filesystem->plugin_filesystem); - auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( - bucket_src, object_src, bucket_dst, object_dst); - if (!metadata) { - TF_SetStatusFromGCSStatus(metadata.status(), status); + if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty())) + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Cannot delete a non-empty directory."); + if (childrens.size() == 1 && childrens[0].empty()) { + // This is the directory marker object. Delete it. + std::string dir = path; + MaybeAppendSlash(&dir); + DeleteFile(filesystem, dir.c_str(), status); return; } - auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src); - TF_SetStatusFromGCSStatus(gcs_status, status); + TF_SetStatus(status, TF_OK, ""); } void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, @@ -556,6 +895,183 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_SetStatusFromGCSStatus(metadata.status(), status); } +bool IsDirectory(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return false; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + if (object.empty()) { + bool result = BucketExists(gcs_file, bucket, status); + if (TF_GetCode(status) != TF_OK) return false; + if (!result) + TF_SetStatus( + status, TF_NOT_FOUND, + ("The specified bucket gs://" + bucket + " was not found.").c_str()); + return result; + } + + bool is_folder = FolderExists(gcs_file, path, status); + if (TF_GetCode(status) != TF_OK) return false; + if (is_folder) return true; + + bool is_object = ObjectExists(gcs_file, path, bucket, object, status); + if (TF_GetCode(status) != TF_OK) return false; + if (is_object) { + TF_SetStatus( + status, TF_FAILED_PRECONDITION, + absl::StrCat("The specified path ", path, " is not a directory.") + .c_str()); + return false; + } + TF_SetStatus(status, TF_NOT_FOUND, + absl::StrCat("The path ", path, " does not exist.").c_str()); + return false; +} + +static void RenameObject(const TF_Filesystem* filesystem, + const std::string& src, const std::string& dst, + TF_Status* status) { + std::string bucket_src, object_src; + ParseGCSPath(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string bucket_dst, object_dst; + ParseGCSPath(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + auto metadata = gcs_file->gcs_client.RewriteObjectBlocking( + bucket_src, object_src, bucket_dst, object_dst); + TF_SetStatusFromGCSStatus(metadata.status(), status); + if (TF_GetCode(status) != TF_OK) return; + + ClearFileCaches(gcs_file, dst); + DeleteFile(filesystem, src.c_str(), status); +} + +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + if (!IsDirectory(filesystem, src, status)) { + if (TF_GetCode(status) == 
TF_FAILED_PRECONDITION) + RenameObject(filesystem, src, dst, status); + return; + } + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string src_dir = src; + std::string dst_dir = dst; + MaybeAppendSlash(&src_dir); + MaybeAppendSlash(&dst_dir); + for (const std::string& children : childrens) { + RenameObject(filesystem, src_dir + children, dst_dir + children, status); + if (TF_GetCode(status) != TF_OK) return; + } + TF_SetStatus(status, TF_OK, ""); +} + +void DeleteRecursively(const TF_Filesystem* filesystem, const char* path, + uint64_t* undeleted_files, uint64_t* undeleted_dirs, + TF_Status* status) { + if (!undeleted_files || !undeleted_dirs) + return TF_SetStatus( + status, TF_INTERNAL, + "'undeleted_files' and 'undeleted_dirs' cannot be nullptr."); + *undeleted_files = 0; + *undeleted_dirs = 0; + if (!IsDirectory(filesystem, path, status)) { + *undeleted_dirs = 1; + return; + } + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string dir = path; + MaybeAppendSlash(&dir); + for (const std::string& children : childrens) { + const std::string& full_path = dir + children; + DeleteFile(filesystem, full_path.c_str(), status); + if (TF_GetCode(status) != TF_OK) { + if (IsDirectory(filesystem, full_path.c_str(), status)) + // The object is a directory marker. + (*undeleted_dirs)++; + else + (*undeleted_files)++; + } + } +} + +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + auto gcs_file = static_cast(filesystem->plugin_filesystem); + std::vector childrens = + GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status); + if (TF_GetCode(status) != TF_OK) return -1; + + int num_entries = childrens.size(); + *entries = static_cast( + plugin_memory_allocate(num_entries * sizeof((*entries)[0]))); + for (int i = 0; i < num_entries; i++) + (*entries)[i] = strdup(childrens[i].c_str()); + TF_SetStatus(status, TF_OK, ""); + return num_entries; +} + +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_file = static_cast(filesystem->plugin_filesystem); + if (object.empty()) { + auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket); + TF_SetStatusFromGCSStatus(bucket_metadata.status(), status); + if (TF_GetCode(status) == TF_OK) { + stats->is_directory = true; + stats->length = 0; + stats->mtime_nsec = 0; + } + return; + } + if (IsDirectory(filesystem, path, status)) { + stats->is_directory = true; + stats->length = 0; + stats->mtime_nsec = 0; + return TF_SetStatus(status, TF_OK, ""); + } + if (TF_GetCode(status) == TF_OK) { + auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object); + if (metadata) { + stats->is_directory = false; + stats->length = metadata.value().size(); + stats->mtime_nsec = metadata.value() + .time_storage_class_updated() + .time_since_epoch() + .count(); + } + TF_SetStatusFromGCSStatus(metadata.status(), status); + } +} + +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + return strdup(uri); +} + +static void FlushCaches(const TF_Filesystem* 
filesystem) { + auto gcs_file = static_cast(filesystem->plugin_filesystem); + absl::ReaderMutexLock l(&gcs_file->block_cache_lock); + gcs_file->file_block_cache->Flush(); + gcs_file->stat_cache->Clear(); +} + } // namespace tf_gcs_filesystem static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, @@ -572,6 +1088,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->read_only_memory_region_ops = static_cast( + plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE)); + ops->read_only_memory_region_ops->cleanup = + tf_read_only_memory_region::Cleanup; + ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data; + ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length; + ops->filesystem_ops = static_cast( plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); ops->filesystem_ops->init = tf_gcs_filesystem::Init; @@ -581,6 +1104,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile; ops->filesystem_ops->new_appendable_file = tf_gcs_filesystem::NewAppendableFile; + ops->filesystem_ops->new_read_only_memory_region_from_file = + tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile; + ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir; + ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile; + ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir; + ops->filesystem_ops->delete_recursively = + tf_gcs_filesystem::DeleteRecursively; + ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile; + ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists; + ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory; + ops->filesystem_ops->stat = tf_gcs_filesystem::Stat; + ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName; + ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches; } void TF_InitPlugin(TF_FilesystemPluginInfo* info) { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h index 93862f4a871..973ce9e9dc2 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h @@ -17,6 +17,8 @@ #include "google/cloud/storage/client.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" #include "tensorflow/c/tf_status.h" void ParseGCSPath(const std::string& fname, bool object_empty_ok, @@ -45,10 +47,34 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region); } // namespace tf_read_only_memory_region namespace tf_gcs_filesystem { +typedef struct GcsFileStat { + TF_FileStatistics base; + int64_t generation_number; +} GcsFileStat; + typedef struct GCSFile { google::cloud::storage::Client gcs_client; // owned bool compose; + absl::Mutex block_cache_lock; + std::shared_ptr file_block_cache + ABSL_GUARDED_BY(block_cache_lock); + uint64_t block_size; // Reads smaller than block_size will trigger a read + // of block_size. 
+ std::unique_ptr> stat_cache; + GCSFile(google::cloud::storage::Client&& gcs_client); + // This constructor is used for testing purposes only. + GCSFile(google::cloud::storage::Client&& gcs_client, bool compose, + uint64_t block_size, size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries); } GCSFile; + +// This function initializes a filesystem for testing without having to set +// environment variables manually. +void InitTest(TF_Filesystem* filesystem, bool compose, uint64_t block_size, + size_t max_bytes, uint64_t max_staleness, + uint64_t stat_cache_max_age, size_t stat_cache_max_entries, + TF_Status* status); + void Init(TF_Filesystem* filesystem, TF_Status* status); void Cleanup(TF_Filesystem* filesystem); void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc index 0e3c855d6c6..82c4e4b8705 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc @@ -66,6 +66,9 @@ static std::string* GetTmpDir() { namespace tensorflow { namespace { +// TODO(vnvo2409): Refactor `gcs_filesystem_test` to remove unnecessary tests +// after porting all tests from +// `//tensorflow/core/platform/cloud:gcs_file_system_test`. class GCSFilesystemTest : public ::testing::Test { public: void SetUp() override { @@ -74,13 +77,14 @@ class GCSFilesystemTest : public ::testing::Test { ::testing::UnitTest::GetInstance()->current_test_info()->name()); status_ = TF_NewStatus(); filesystem_ = new TF_Filesystem; - tf_gcs_filesystem::Init(filesystem_, status_); - ASSERT_TF_OK(status_) << "Could not initialize filesystem. " - << TF_Message(status_); + filesystem_->plugin_filesystem = nullptr; + // Different tests require different filesystem setups, so we initialize + // the filesystem in each test case instead. } void TearDown() override { TF_DeleteStatus(status_); - tf_gcs_filesystem::Cleanup(filesystem_); + if (filesystem_->plugin_filesystem != nullptr) + tf_gcs_filesystem::Cleanup(filesystem_); delete filesystem_; } @@ -117,6 +121,21 @@ class GCSFilesystemTest : public ::testing::Test { } } +::testing::AssertionResult InsertObject(const std::string& path, + const std::string& content, + gcs::Client* gcs_client, + TF_Status* status) { + std::string bucket, object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) + return ::testing::AssertionFailure() << TF_Message(status); + auto metadata = gcs_client->InsertObject(bucket, object, content); + if (metadata) + return ::testing::AssertionSuccess(); + else + return ::testing::AssertionFailure() << metadata.status().message(); +} + ::testing::AssertionResult CompareSubString(int64_t offset, size_t length, absl::string_view result, size_t read) { @@ -172,6 +191,9 @@ TEST_F(GCSFilesystemTest, ParseGCSPath) { } TEST_F(GCSFilesystemTest, RandomAccessFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. 
" + << TF_Message(status_); std::string filepath = GetURIForPath("a_file"); TF_RandomAccessFile* file = new TF_RandomAccessFile; tf_gcs_filesystem::NewRandomAccessFile(filesystem_, filepath.c_str(), file, @@ -208,6 +230,9 @@ TEST_F(GCSFilesystemTest, RandomAccessFile) { } TEST_F(GCSFilesystemTest, WritableFile) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. " + << TF_Message(status_); std::string filepath = GetURIForPath("a_file"); TF_WritableFile* file = new TF_WritableFile; tf_gcs_filesystem::NewWritableFile(filesystem_, filepath.c_str(), file, @@ -273,6 +298,9 @@ TEST_F(GCSFilesystemTest, WritableFile) { } TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) { + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. " + << TF_Message(status_); std::string path = GetURIForPath("a_file"); auto gcs_file = static_cast(filesystem_->plugin_filesystem); @@ -298,6 +326,131 @@ TEST_F(GCSFilesystemTest, ReadOnlyMemoryRegion) { delete region; } +// These tests below are ported from +// `//tensorflow/core/platform/cloud:gcs_file_system_test` +TEST_F(GCSFilesystemTest, NewRandomAccessFile_NoBlockCache) { + tf_gcs_filesystem::InitTest(filesystem_, false, 0, 0, 0, 0, 0, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. " + << TF_Message(status_); + std::string path = GetURIForPath("a_file"); + auto gcs_file = + static_cast(filesystem_->plugin_filesystem); + ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_)); + + TF_RandomAccessFile* file = new TF_RandomAccessFile; + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file, + status_); + ASSERT_TF_OK(status_); + + std::string result; + result.resize(6); + int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_); + ASSERT_EQ(read, 6) << "Read: " << read << "\n"; + ASSERT_TF_OK(status_); + ASSERT_EQ(result, "012345") << "Result: " << result << "\n"; + + read = tf_random_access_file::Read(file, 6, 6, &result[0], status_); + ASSERT_EQ(read, 4) << "Read: " << read << "\n"; + ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_); + result.resize(read); + ASSERT_EQ(result, "6789") << "Result: " << result << "\n"; +} + +TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered) { + tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. 
" + << TF_Message(status_); + std::string path = GetURIForPath("a_file"); + auto gcs_file = + static_cast(filesystem_->plugin_filesystem); + ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_)); + + TF_RandomAccessFile* file = new TF_RandomAccessFile; + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file, + status_); + ASSERT_TF_OK(status_); + + std::string result; + result.resize(6); + int64_t read = tf_random_access_file::Read(file, 0, 6, &result[0], status_); + ASSERT_EQ(read, 6) << "Read: " << read << "\n"; + ASSERT_TF_OK(status_); + ASSERT_EQ(result, "012345") << "Result: " << result << "\n"; + + read = tf_random_access_file::Read(file, 6, 6, &result[0], status_); + ASSERT_EQ(read, 4) << "Read: " << read << "\n"; + ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_); + result.resize(read); + ASSERT_EQ(result, "6789") << "Result: " << result << "\n"; +} + +TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { + tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. " + << TF_Message(status_); + std::string path = GetURIForPath("a_file"); + auto gcs_file = + static_cast(filesystem_->plugin_filesystem); + ASSERT_TRUE(InsertObject(path, "0123456789", &gcs_file->gcs_client, status_)); + + TF_RandomAccessFile* file = new TF_RandomAccessFile; + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file, + status_); + ASSERT_TF_OK(status_); + + std::string result; + result.resize(10); + int64_t read = tf_random_access_file::Read(file, 0, result.length(), + &result[0], status_); + ASSERT_EQ(read, 10) << "Read: " << read << "\n"; + ASSERT_TF_OK(status_); + ASSERT_EQ(result, "0123456789") << "Result: " << result << "\n"; + + read = tf_random_access_file::Read(file, result.length(), result.length(), + &result[0], status_); + ASSERT_EQ(read, 0) << "Read: " << read << "\n"; + ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_); + result.resize(read); + ASSERT_EQ(result, "") << "Result: " << result << "\n"; +} + +TEST_F(GCSFilesystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { + tf_gcs_filesystem::InitTest(filesystem_, false, 10, 0, 0, 0, 0, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. 
" + << TF_Message(status_); + std::string path = GetURIForPath("a_file"); + auto gcs_file = + static_cast(filesystem_->plugin_filesystem); + ASSERT_TRUE(InsertObject(path, "012345678", &gcs_file->gcs_client, status_)); + + TF_RandomAccessFile* file = new TF_RandomAccessFile; + tf_gcs_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), file, + status_); + ASSERT_TF_OK(status_); + + std::string result; + result.resize(5); + int64_t read = tf_random_access_file::Read(file, 0, result.length(), + &result[0], status_); + ASSERT_EQ(read, 5) << "Read: " << read << "\n"; + ASSERT_TF_OK(status_); + ASSERT_EQ(result, "01234") << "Result: " << result << "\n"; + + read = tf_random_access_file::Read(file, 4, result.length(), &result[0], + status_); + ASSERT_EQ(read, 5) << "Read: " << read << "\n"; + ASSERT_TF_OK(status_); + result.resize(read); + ASSERT_EQ(result, "45678") << "Result: " << result << "\n"; + + read = tf_random_access_file::Read(file, 5, result.length(), &result[0], + status_); + ASSERT_EQ(read, 4) << "Read: " << read << "\n"; + ASSERT_EQ(TF_GetCode(status_), TF_OUT_OF_RANGE) << TF_Message(status_); + result.resize(read); + ASSERT_EQ(result, "5678") << "Result: " << result << "\n"; +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc index 102c7fa175c..3700ccf17a2 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc @@ -39,9 +39,6 @@ std::shared_ptr RamFileBlockCache::Lookup( auto entry = block_map_.find(key); if (entry != block_map_.end()) { if (BlockNotStale(entry->second)) { - if (cache_stats_ != nullptr) { - cache_stats_->RecordCacheHitBlockSize(entry->second->data.size()); - } return entry->second; } else { // Remove the stale block and continue. @@ -136,12 +133,9 @@ void RamFileBlockCache::MaybeFetch(const Key& key, block->mu.Unlock(); // Release the lock while making the API call. block->data.clear(); block->data.resize(block_size_, 0); - size_t bytes_transferred; - block_fetcher_(key.first, key.second, block_size_, block->data.data(), - &bytes_transferred, status); - if (cache_stats_ != nullptr) { - cache_stats_->RecordCacheMissBlockSize(bytes_transferred); - } + int64_t bytes_transferred; + bytes_transferred = block_fetcher_(key.first, key.second, block_size_, + block->data.data(), status); block->mu.Lock(); // Reacquire the lock immediately afterwards if (TF_GetCode(status) == TF_OK) { block->data.resize(bytes_transferred, 0); @@ -171,18 +165,16 @@ void RamFileBlockCache::MaybeFetch(const Key& key, "Control flow should never reach the end of RamFileBlockCache::Fetch."); } -void RamFileBlockCache::Read(const std::string& filename, size_t offset, - size_t n, char* buffer, size_t* bytes_transferred, - TF_Status* status) { - *bytes_transferred = 0; +int64_t RamFileBlockCache::Read(const std::string& filename, size_t offset, + size_t n, char* buffer, TF_Status* status) { if (n == 0) { - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return 0; } if (!IsCacheEnabled() || (n > max_bytes_)) { // The cache is effectively disabled, so we pass the read through to the // fetcher without breaking it up into blocks. 
- return block_fetcher_(filename, offset, n, buffer, bytes_transferred, - status); + return block_fetcher_(filename, offset, n, buffer, status); } // Calculate the block-aligned start and end of the read. size_t start = block_size_ * (offset / block_size_); @@ -202,20 +194,20 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset, abort(); } MaybeFetch(key, block, status); - if (TF_GetCode(status) != TF_OK) return; + if (TF_GetCode(status) != TF_OK) return -1; UpdateLRU(key, block, status); - if (TF_GetCode(status) != TF_OK) return; + if (TF_GetCode(status) != TF_OK) return -1; // Copy the relevant portion of the block into the result buffer. const auto& data = block->data; if (offset >= pos + data.size()) { // The requested offset is at or beyond the end of the file. This can // happen if `offset` is not block-aligned, and the read returns the last // block in the file, which does not extend all the way out to `offset`. - *bytes_transferred = total_bytes_transferred; std::stringstream os; os << "EOF at offset " << offset << " in file " << filename << " at position " << pos << " with data size " << data.size(); - return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str()); + TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str()); + return total_bytes_transferred; } auto begin = data.begin(); if (offset > pos) { @@ -237,8 +229,8 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset, break; } } - *bytes_transferred = total_bytes_transferred; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return total_bytes_transferred; } bool RamFileBlockCache::ValidateAndUpdateFileSignature( diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h index 5a82f65db41..2abfb6f924b 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -28,7 +28,6 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" #include "tensorflow/c/env.h" -#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h" #include "tensorflow/c/tf_status.h" namespace tf_gcs_filesystem { @@ -37,16 +36,17 @@ namespace tf_gcs_filesystem { /// /// This class should be shared by read-only random access files on a remote /// filesystem (e.g. GCS). -class RamFileBlockCache : public FileBlockCache { +class RamFileBlockCache { public: /// The callback executed when a block is not found in the cache, and needs to /// be fetched from the backing filesystem. This callback is provided when the - /// cache is constructed. The `status` should be `TF_OK` as long as the - /// read from the remote filesystem succeeded (similar to the semantics of the - /// read(2) system call). - typedef std::function + /// cache is constructed. It returns total bytes read ( -1 in case of errors + /// ). The `status` should be `TF_OK` as long as the read from the remote + /// filesystem succeeded (similar to the semantics of the read(2) system + /// call). + typedef std::function BlockFetcher; RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, @@ -66,10 +66,10 @@ class RamFileBlockCache : public FileBlockCache { TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); } std::cout << "GCS file block cache is " - << (IsCacheEnabled() ? 
"enabled" : "disabled"); + << (IsCacheEnabled() ? "enabled" : "disabled") << ".\n"; } - ~RamFileBlockCache() override { + ~RamFileBlockCache() { if (pruning_thread_) { stop_pruning_thread_.Notify(); // Destroying pruning_thread_ will block until Prune() receives the above @@ -78,8 +78,9 @@ class RamFileBlockCache : public FileBlockCache { } } - /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This - /// method will set `status` to: + /// Read `n` bytes from `filename` starting at `offset` into `buffer`. It + /// returns total bytes read ( -1 in case of errors ). This method will set + /// `status` to: /// /// 1) The error from the remote filesystem, if the read from the remote /// filesystem failed. @@ -97,37 +98,34 @@ class RamFileBlockCache : public FileBlockCache { /// /// Caller is responsible for allocating memory for `buffer`. /// `buffer` will be left unchanged in case of errors. - void Read(const std::string& filename, size_t offset, size_t n, char* buffer, - size_t* bytes_transferred, TF_Status* status) override; + int64_t Read(const std::string& filename, size_t offset, size_t n, + char* buffer, TF_Status* status); // Validate the given file signature with the existing file signature in the // cache. Returns true if the signature doesn't change or the file doesn't // exist before. If the signature changes, update the existing signature with // the new one and remove the file from cache. bool ValidateAndUpdateFileSignature(const std::string& filename, - int64_t file_signature) override + int64_t file_signature) ABSL_LOCKS_EXCLUDED(mu_); /// Remove all cached blocks for `filename`. - void RemoveFile(const std::string& filename) override - ABSL_LOCKS_EXCLUDED(mu_); + void RemoveFile(const std::string& filename) ABSL_LOCKS_EXCLUDED(mu_); /// Remove all cached data. - void Flush() override ABSL_LOCKS_EXCLUDED(mu_); + void Flush() ABSL_LOCKS_EXCLUDED(mu_); /// Accessors for cache parameters. - size_t block_size() const override { return block_size_; } - size_t max_bytes() const override { return max_bytes_; } - uint64_t max_staleness() const override { return max_staleness_; } + size_t block_size() const { return block_size_; } + size_t max_bytes() const { return max_bytes_; } + uint64_t max_staleness() const { return max_staleness_; } /// The current size (in bytes) of the cache. - size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_); + size_t CacheSize() const ABSL_LOCKS_EXCLUDED(mu_); // Returns true if the cache is enabled. If false, the BlockFetcher callback // is always executed during Read. - bool IsCacheEnabled() const override { - return block_size_ > 0 && max_bytes_ > 0; - } + bool IsCacheEnabled() const { return block_size_ > 0 && max_bytes_ > 0; } // We can not pass a lambda with capture as a function pointer to // `TF_StartThread`, so we have to wrap `Prune` inside a static function. 
diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc index b1ea295c080..859d42d85e3 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache_test.cc @@ -33,20 +33,22 @@ Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache, std::vector* out) { out->clear(); out->resize(n, 0); - size_t bytes_transferred = 0; TF_Status status; - cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status); - EXPECT_LE(bytes_transferred, n); - out->resize(bytes_transferred, n); + auto bytes_transferred = + cache->Read(filename, offset, n, out->data(), &status); + if (bytes_transferred >= 0) { + EXPECT_LE(bytes_transferred, n); + out->resize(bytes_transferred, n); + } return status.status; } TEST(RamFileBlockCacheTest, IsCacheEnabled) { auto fetcher = [](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { // Do nothing. - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return 0; }; tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher); tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher); @@ -62,12 +64,11 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) { TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; string filename = "file"; tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); @@ -96,15 +97,14 @@ TEST(RamFileBlockCacheTest, PassThrough) { int calls = 0; auto fetcher = [&calls, want_filename, want_offset, want_n]( const string& got_filename, size_t got_offset, - size_t got_n, char* buffer, size_t* bytes_transferred, - TF_Status* status) { + size_t got_n, char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(got_filename, want_filename); EXPECT_EQ(got_offset, want_offset); EXPECT_EQ(got_n, want_n); calls++; memset(buffer, 'x', got_n); - *bytes_transferred = got_n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return got_n; }; // If block_size, max_bytes, or both are zero, or want_n is larger than // max_bytes the cache is a pass-through. @@ -133,16 +133,17 @@ TEST(RamFileBlockCacheTest, BlockAlignment) { } // The fetcher just fetches slices of the buffer. 
auto fetcher = [&buf](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { + int64_t bytes_transferred; if (offset < buf.size()) { size_t bytes_to_copy = std::min(buf.size() - offset, n); memcpy(buffer, buf.data() + offset, bytes_to_copy); - *bytes_transferred = bytes_to_copy; + bytes_transferred = bytes_to_copy; } else { - *bytes_transferred = 0; + bytes_transferred = 0; } - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return bytes_transferred; }; for (size_t block_size = 2; block_size <= 4; block_size++) { // Make a cache of N-byte block size (1 block) and verify that reads of @@ -181,15 +182,14 @@ TEST(RamFileBlockCacheTest, CacheHits) { std::set calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, size_t n, char* buffer, - size_t* bytes_transferred, - TF_Status* status) { + TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset; calls.insert(offset); memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; const uint32 block_count = 256; tf_gcs_filesystem::RamFileBlockCache cache( @@ -215,8 +215,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) { bool second_block = false; auto fetcher = [block_size, file_size, &first_block, &second_block]( const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); size_t bytes_to_copy = 0; @@ -231,8 +230,8 @@ TEST(RamFileBlockCacheTest, OutOfRange) { memset(buffer, 'x', bytes_to_copy); second_block = true; } - *bytes_transferred = bytes_to_copy; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return bytes_to_copy; }; tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, fetcher); @@ -260,14 +259,13 @@ TEST(RamFileBlockCacheTest, Inconsistent) { const size_t block_size = 16; // This fetcher returns OK but only fills in one byte for any offset. 
auto fetcher = [block_size](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); EXPECT_GE(n, 1); memset(buffer, 'x', 1); - *bytes_transferred = 1; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return 1; }; tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0, fetcher); @@ -286,8 +284,7 @@ TEST(RamFileBlockCacheTest, LRU) { std::list calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, size_t n, char* buffer, - size_t* bytes_transferred, - TF_Status* status) { + TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_FALSE(calls.empty()) << "at offset = " << offset; if (!calls.empty()) { @@ -295,8 +292,8 @@ TEST(RamFileBlockCacheTest, LRU) { calls.pop_front(); } memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; const uint32 block_count = 2; tf_gcs_filesystem::RamFileBlockCache cache( @@ -335,12 +332,11 @@ TEST(RamFileBlockCacheTest, LRU) { TEST(RamFileBlockCacheTest, MaxStaleness) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; std::vector out; std::unique_ptr env(new NowSecondsEnv); @@ -380,8 +376,7 @@ TEST(RamFileBlockCacheTest, MaxStaleness) { TEST(RamFileBlockCacheTest, RemoveFile) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { calls++; char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; if (offset > 0) { @@ -389,8 +384,8 @@ TEST(RamFileBlockCacheTest, RemoveFile) { c = toupper(c); } memset(buffer, c, n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; // This cache has space for 4 blocks; we'll read from two files. const size_t n = 3; @@ -443,12 +438,11 @@ TEST(RamFileBlockCacheTest, RemoveFile) { TEST(RamFileBlockCacheTest, Prune) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; std::vector out; // Our fake environment is initialized with the current timestamp. @@ -509,17 +503,17 @@ TEST(RamFileBlockCacheTest, ParallelReads) { const int callers = 4; BlockingCounter counter(callers); auto fetcher = [&counter](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { counter.DecrementCount(); if (!counter.WaitFor(std::chrono::seconds(10))) { // This avoids having the test time out, which is harder to debug. 
- return TF_SetStatus(status, TF_FAILED_PRECONDITION, - "desired concurrency not reached"); + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "desired concurrency not reached"); + return -1; } memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; const int block_size = 8; tf_gcs_filesystem::RamFileBlockCache cache( @@ -548,17 +542,16 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { EXPECT_EQ(n, block_size); EXPECT_EQ(offset, 0); num_requests++; memset(buffer, 'x', n); - *bytes_transferred = n; notification.Notify(); // Wait for other thread to issue read. Env::Default()->SleepForMicroseconds(100000); // 0.1 secs - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0, fetcher); @@ -580,12 +573,11 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { TEST(RamFileBlockCacheTest, Flush) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred, - TF_Status* status) { + char* buffer, TF_Status* status) -> int64_t { calls++; memset(buffer, 'x', n); - *bytes_transferred = n; - return TF_SetStatus(status, TF_OK, ""); + TF_SetStatus(status, TF_OK, ""); + return n; }; tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher); std::vector out; diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD new file mode 100644 index 00000000000..51ffd709f3d --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/BUILD @@ -0,0 +1,35 @@ +# Experimental hadoop filesystem plugin. +load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") + +package( + licenses = ["notice"], # Apache 2.0 +) + +# Filesystem implementation for HADOOP environments +tf_cc_shared_object( + name = "hadoop_filesystem", + framework_so = [], + linkstatic = False, + per_os_targets = 1, + visibility = ["//visibility:public"], + deps = [":hadoop_filesystem_impl"], +) + +# The real implementation of the filesystem. +cc_library( + name = "hadoop_filesystem_impl", + srcs = ["hadoop_filesystem.cc"], + hdrs = ["hadoop_filesystem.h"], + copts = select({ + "//conditions:default": [], + "//tensorflow:windows": get_win_copts(), + }), + deps = [ + "//tensorflow/c:env", + "//tensorflow/c:tf_status", + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//third_party/hadoop:hdfs", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc new file mode 100644 index 00000000000..e53e3d0bcc5 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.cc @@ -0,0 +1,660 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h" + +#include +#include + +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/tf_status.h" +#include "third_party/hadoop/hdfs.h" + +// Implementation of a filesystem for HADOOP environments. +// This filesystem will support `hdfs://`, `viewfs://` and `har://` URI schemes. + +static void* plugin_memory_allocate(size_t size) { return calloc(1, size); } +static void plugin_memory_free(void* ptr) { free(ptr); } + +void ParseHadoopPath(const std::string& fname, std::string* scheme, + std::string* namenode, std::string* path) { + size_t scheme_end = fname.find("://") + 2; + *scheme = fname.substr(0, scheme_end + 1); + size_t nn_end = fname.find("/", scheme_end + 1); + if (nn_end == std::string::npos) return; + *namenode = fname.substr(scheme_end + 1, nn_end - scheme_end - 1); + *path = fname.substr(nn_end + 1); +} + +void SplitArchiveNameAndPath(std::string* path, std::string* nn, + TF_Status* status) { + size_t index_end_archive_name = path->find(".har"); + if (index_end_archive_name == path->npos) { + return TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "Hadoop archive path does not contain a .har extension"); + } + // Case of hadoop archive. Namenode is the path to the archive. 
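+  // For example (values are illustrative): with *nn == "namenode" and
+  // *path == "/user/data.har/dir/file", the code below produces
+  // *nn == "har://namenode/user/data.har" and *path == "/dir/file".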
+ std::ostringstream namenodestream; + namenodestream << "har://" << nn + << path->substr(0, index_end_archive_name + 4); + *nn = namenodestream.str(); + path->erase(0, index_end_archive_name + 4); + if (path->empty()) + // Root of the archive + *path = "/"; + return TF_SetStatus(status, TF_OK, ""); +} + +template +void BindFunc(void* handle, const char* name, std::function* func, + TF_Status* status) { + *func = reinterpret_cast( + TF_GetSymbolFromLibrary(handle, name, status)); +} + +class LibHDFS { + public: + explicit LibHDFS(TF_Status* status) { LoadAndBind(status); } + + std::function hdfsBuilderConnect; + std::function hdfsNewBuilder; + std::function hdfsBuilderSetNameNode; + std::function hdfsConfGetStr; + std::function hdfsCloseFile; + std::function hdfsPread; + std::function hdfsWrite; + std::function hdfsHFlush; + std::function hdfsHSync; + std::function hdfsTell; + std::function + hdfsOpenFile; + std::function hdfsExists; + std::function hdfsListDirectory; + std::function hdfsFreeFileInfo; + std::function hdfsDelete; + std::function hdfsCreateDirectory; + std::function hdfsGetPathInfo; + std::function hdfsRename; + + private: + void LoadAndBind(TF_Status* status) { + auto TryLoadAndBind = [this](const char* name, void** handle, + TF_Status* status) { + *handle = TF_LoadSharedLibrary(name, status); + if (TF_GetCode(status) != TF_OK) return; + +#define BIND_HDFS_FUNC(function) \ + do { \ + BindFunc(*handle, #function, &function, status); \ + if (TF_GetCode(status) != TF_OK) return; \ + } while (0); + + BIND_HDFS_FUNC(hdfsBuilderConnect); + BIND_HDFS_FUNC(hdfsNewBuilder); + BIND_HDFS_FUNC(hdfsBuilderSetNameNode); + BIND_HDFS_FUNC(hdfsConfGetStr); + BIND_HDFS_FUNC(hdfsCloseFile); + BIND_HDFS_FUNC(hdfsPread); + BIND_HDFS_FUNC(hdfsWrite); + BIND_HDFS_FUNC(hdfsHFlush); + BIND_HDFS_FUNC(hdfsTell); + BIND_HDFS_FUNC(hdfsHSync); + BIND_HDFS_FUNC(hdfsOpenFile); + BIND_HDFS_FUNC(hdfsExists); + BIND_HDFS_FUNC(hdfsListDirectory); + BIND_HDFS_FUNC(hdfsFreeFileInfo); + BIND_HDFS_FUNC(hdfsDelete); + BIND_HDFS_FUNC(hdfsCreateDirectory); + BIND_HDFS_FUNC(hdfsGetPathInfo); + BIND_HDFS_FUNC(hdfsRename); + +#undef BIND_HDFS_FUNC + }; + + // libhdfs.so won't be in the standard locations. Use the path as specified + // in the libhdfs documentation. +#if defined(_WIN32) + constexpr char kLibHdfsDso[] = "hdfs.dll"; +#elif defined(__GNUC__) && (defined(__APPLE_CPP__) || defined(__APPLE_CC__) || \ + defined(__MACOS_CLASSIC__)) + constexpr char kLibHdfsDso[] = "libhdfs.dylib"; +#else + constexpr char kLibHdfsDso[] = "libhdfs.so"; +#endif + char* hdfs_home = getenv("HADOOP_HDFS_HOME"); + if (hdfs_home != nullptr) { + auto JoinPath = [](std::string home, std::string lib) { + if (home.back() != '/') home.push_back('/'); + return home + "lib/native/" + lib; + }; + std::string path = JoinPath(hdfs_home, kLibHdfsDso); + TryLoadAndBind(path.c_str(), &handle_, status); + if (TF_GetCode(status) == TF_OK) { + return; + } else { + std::cerr << "HadoopFileSystem load error: " << TF_Message(status); + } + } + + // Try to load the library dynamically in case it has been installed + // to a in non-standard location. + TryLoadAndBind(kLibHdfsDso, &handle_, status); + } + + void* handle_; +}; + +// We rely on HDFS connection caching here. The HDFS client calls +// org.apache.hadoop.fs.FileSystem.get(), which caches the connection +// internally. 
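+// Parses `path`, configures an hdfsBuilder for its scheme (file, viewfs, har
+// or plain hdfs) and returns the connected hdfsFS, or nullptr with `status`
+// set on failure.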
+hdfsFS Connect(LibHDFS* libhdfs, const std::string& path, TF_Status* status) { + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + hdfsBuilder* builder = libhdfs->hdfsNewBuilder(); + if (scheme == "file") { + libhdfs->hdfsBuilderSetNameNode(builder, nullptr); + } else if (scheme == "viewfs") { + char* defaultFS = nullptr; + libhdfs->hdfsConfGetStr("fs.defaultFS", &defaultFS); + std::string defaultScheme, defaultCluster, defaultPath; + ParseHadoopPath(defaultFS, &defaultScheme, &defaultCluster, &defaultPath); + + if (scheme != defaultScheme || + (namenode.empty() && namenode != defaultCluster)) { + TF_SetStatus(status, TF_UNIMPLEMENTED, + "viewfs is only supported as a fs.defaultFS."); + return nullptr; + } + // The default NameNode configuration will be used (from the XML + // configuration files). See: + // https://github.com/tensorflow/tensorflow/blob/v1.0.0/third_party/hadoop/hdfs.h#L259 + libhdfs->hdfsBuilderSetNameNode(builder, "default"); + } else if (scheme == "har") { + std::string path_har = path; + SplitArchiveNameAndPath(&path_har, &namenode, status); + if (TF_GetCode(status) != TF_OK) return nullptr; + libhdfs->hdfsBuilderSetNameNode(builder, namenode.c_str()); + } else { + libhdfs->hdfsBuilderSetNameNode( + builder, namenode.empty() ? "default" : namenode.c_str()); + } + auto fs = libhdfs->hdfsBuilderConnect(builder); + if (fs == nullptr) + TF_SetStatusFromIOError(status, TF_NOT_FOUND, strerror(errno)); + else + TF_SetStatus(status, TF_OK, ""); + return fs; +} + +// SECTION 1. Implementation for `TF_RandomAccessFile` +// ---------------------------------------------------------------------------- +namespace tf_random_access_file { +typedef struct HDFSFile { + std::string path; + std::string hdfs_path; + hdfsFS fs; + LibHDFS* libhdfs; + absl::Mutex mu; + hdfsFile handle ABSL_GUARDED_BY(mu); + HDFSFile(std::string path, std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, + hdfsFile handle) + : path(std::move(path)), + hdfs_path(std::move(hdfs_path)), + fs(fs), + libhdfs(libhdfs), + mu(), + handle(handle) {} +} HDFSFile; + +void Cleanup(TF_RandomAccessFile* file) { + auto hdfs_file = static_cast(file->plugin_file); + { + absl::MutexLock l(&hdfs_file->mu); + if (hdfs_file->handle != nullptr) { + hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle); + } + } + delete hdfs_file; +} + +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + auto libhdfs = hdfs_file->libhdfs; + auto fs = hdfs_file->fs; + auto hdfs_path = hdfs_file->hdfs_path.c_str(); + auto path = hdfs_file->path.c_str(); + + char* dst = buffer; + bool eof_retried = false; + int64_t r = 0; + while (TF_GetCode(status) == TF_OK && !eof_retried) { + // We lock inside the loop rather than outside so we don't block other + // concurrent readers. + absl::MutexLock l(&hdfs_file->mu); + auto handle = hdfs_file->handle; + // Max read length is INT_MAX-2, for hdfsPread function take a parameter + // of int32. -2 offset can avoid JVM OutOfMemoryError. + size_t read_n = + (std::min)(n, static_cast(std::numeric_limits::max() - 2)); + r = libhdfs->hdfsPread(fs, handle, static_cast(offset), dst, + static_cast(read_n)); + if (r > 0) { + dst += r; + n -= r; + offset += r; + } else if (!eof_retried && r == 0) { + // Always reopen the file upon reaching EOF to see if there's more data. 
+ // If writers are streaming contents while others are concurrently + // reading, HDFS requires that we reopen the file to see updated + // contents. + // + // Fixes #5438 + if (handle != nullptr && libhdfs->hdfsCloseFile(fs, handle) != 0) { + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + handle = libhdfs->hdfsOpenFile(fs, hdfs_path, O_RDONLY, 0, 0, 0); + if (handle == nullptr) { + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + eof_retried = true; + } else if (eof_retried && r == 0) { + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + } else if (errno == EINTR || errno == EAGAIN) { + // hdfsPread may return EINTR too. Just retry. + } else { + TF_SetStatusFromIOError(status, errno, path); + } + } + return r; +} + +} // namespace tf_random_access_file + +// SECTION 2. Implementation for `TF_WritableFile` +// ---------------------------------------------------------------------------- +namespace tf_writable_file { +typedef struct HDFSFile { + std::string hdfs_path; + hdfsFS fs; + LibHDFS* libhdfs; + hdfsFile handle; + HDFSFile(std::string hdfs_path, hdfsFS fs, LibHDFS* libhdfs, hdfsFile handle) + : hdfs_path(std::move(hdfs_path)), + fs(fs), + libhdfs(libhdfs), + handle(handle) {} +} HDFSFile; + +static void Cleanup(TF_WritableFile* file) { + auto hdfs_file = static_cast(file->plugin_file); + hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle); + hdfs_file->fs = nullptr; + hdfs_file->handle = nullptr; + delete hdfs_file; +} + +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + auto libhdfs = hdfs_file->libhdfs; + auto fs = hdfs_file->fs; + auto handle = hdfs_file->handle; + + size_t cur_pos = 0, write_len = 0; + bool retry = false; + // max() - 2 can avoid OutOfMemoryError in JVM . 
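+  // hdfsWrite takes a tSize (int32) length, so large appends are written in
+  // chunks of at most `max_len_once` bytes; only the first EINTR/EAGAIN
+  // failure is retried, a second transient failure aborts the append.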
+ static const size_t max_len_once = + static_cast(std::numeric_limits::max() - 2); + while (cur_pos < n) { + write_len = (std::min)(n - cur_pos, max_len_once); + tSize w = libhdfs->hdfsWrite(fs, handle, buffer + cur_pos, + static_cast(write_len)); + if (w == -1) { + if (!retry && (errno == EINTR || errno == EAGAIN)) { + retry = true; + } else { + return TF_SetStatusFromIOError(status, errno, + hdfs_file->hdfs_path.c_str()); + } + } else { + cur_pos += w; + } + } + TF_SetStatus(status, TF_OK, ""); +} + +int64_t Tell(const TF_WritableFile* file, TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + int64_t position = + hdfs_file->libhdfs->hdfsTell(hdfs_file->fs, hdfs_file->handle); + if (position == -1) + TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str()); + else + TF_SetStatus(status, TF_OK, ""); + return position; +} + +void Flush(const TF_WritableFile* file, TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + if (hdfs_file->libhdfs->hdfsHFlush(hdfs_file->fs, hdfs_file->handle) != 0) + TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str()); + else + TF_SetStatus(status, TF_OK, ""); +} + +void Sync(const TF_WritableFile* file, TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + if (hdfs_file->libhdfs->hdfsHSync(hdfs_file->fs, hdfs_file->handle) != 0) + TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str()); + else + TF_SetStatus(status, TF_OK, ""); +} + +void Close(const TF_WritableFile* file, TF_Status* status) { + auto hdfs_file = static_cast(file->plugin_file); + TF_SetStatus(status, TF_OK, ""); + if (hdfs_file->libhdfs->hdfsCloseFile(hdfs_file->fs, hdfs_file->handle) != 0) + TF_SetStatusFromIOError(status, errno, hdfs_file->hdfs_path.c_str()); + hdfs_file->fs = nullptr; + hdfs_file->handle = nullptr; +} + +} // namespace tf_writable_file + +// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` +// ---------------------------------------------------------------------------- +namespace tf_read_only_memory_region { + +// TODO(vnvo2409): Implement later + +} // namespace tf_read_only_memory_region + +// SECTION 4. 
Implementation for `TF_Filesystem`, the actual filesystem +// ---------------------------------------------------------------------------- +namespace tf_hadoop_filesystem { + +void Init(TF_Filesystem* filesystem, TF_Status* status) { + filesystem->plugin_filesystem = new LibHDFS(status); + if (TF_GetCode(status) != TF_OK) return; + TF_SetStatus(status, TF_OK, ""); +} + +void Cleanup(TF_Filesystem* filesystem) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + delete libhdfs; +} + +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), O_RDONLY, 0, 0, 0); + if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + file->plugin_file = + new tf_random_access_file::HDFSFile(path, hdfs_path, fs, libhdfs, handle); + TF_SetStatus(status, TF_OK, ""); +} + +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto handle = libhdfs->hdfsOpenFile(fs, hdfs_path.c_str(), + O_WRONLY | O_APPEND, 0, 0, 0); + if (handle == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + file->plugin_file = + new tf_writable_file::HDFSFile(hdfs_path, fs, libhdfs, handle); + TF_SetStatus(status, TF_OK, ""); +} + +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status) { + // hadoopReadZero() technically supports this call with the following + // caveats: + // - It only works up to 2 GB. We'd have to Stat() the file to ensure that + // it fits. + // - If not on the local filesystem, the entire file will be read, making + // it inefficient for callers that assume typical mmap() behavior. 
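+  // Given these caveats, the plugin reports UNIMPLEMENTED instead of trying
+  // to emulate mmap() semantics on top of hadoopReadZero().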
+ TF_SetStatus(status, TF_UNIMPLEMENTED, + "HDFS does not support ReadOnlyMemoryRegion"); +} + +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0) + TF_SetStatus(status, TF_OK, ""); + else + TF_SetStatus(status, TF_NOT_FOUND, + (std::string(path) + " not found").c_str()); +} + +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str()); + if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path); + + stats->length = static_cast(info->mSize); + stats->mtime_nsec = static_cast(info->mLastMod) * 1e9; + stats->is_directory = info->mKind == kObjectKindDirectory; + libhdfs->hdfsFreeFileInfo(info, 1); + TF_SetStatus(status, TF_OK, ""); +} + +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return -1; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str()); + if (info == nullptr) { + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + + TF_SetStatus(status, TF_OK, ""); + auto size = static_cast(info->mSize); + libhdfs->hdfsFreeFileInfo(info, 1); + return size; +} + +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + // Count the number of entries in the directory, and only delete if it's + // non-empty. 
This is consistent with the interface, but note that there's + // a race condition where a file may be added after this check, in which + // case the directory will still be deleted. + int entries = 0; + auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries); + if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries); + + // Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty + // folder, especially for Kerberos enable setup, EAGAIN is quite common when + // the call is actually successful. Check again by Stat. + if (info == nullptr && errno != 0) { + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + if (TF_GetCode(status) != TF_OK) return; + } + + if (entries > 0) + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Cannot delete a non-empty directory."); + + if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0) + TF_SetStatusFromIOError(status, errno, path); + else + TF_SetStatus(status, TF_OK, ""); +} + +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, src, status); + if (TF_GetCode(status) != TF_OK) return; + + std::string scheme, namenode, hdfs_path_src, hdfs_path_dst; + ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src); + ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst); + + if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 && + libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0) + return TF_SetStatusFromIOError(status, errno, dst); + + if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) != + 0) + TF_SetStatusFromIOError(status, errno, src); + else + TF_SetStatus(status, TF_OK, ""); +} + +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + auto libhdfs = static_cast(filesystem->plugin_filesystem); + auto fs = Connect(libhdfs, path, status); + if (TF_GetCode(status) != TF_OK) return -1; + + std::string scheme, namenode, hdfs_path; + ParseHadoopPath(path, &scheme, &namenode, &hdfs_path); + + // hdfsListDirectory returns nullptr if the directory is empty. Do a separate + // check to verify the directory exists first. + TF_FileStatistics stat; + Stat(filesystem, path, &stat, status); + if (TF_GetCode(status) != TF_OK) return -1; + + int num_entries = 0; + auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries); + if (info == nullptr) { + if (stat.is_directory) { + // Assume it's an empty directory. 
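+      // (Stat() above already confirmed the path exists and is a directory,
+      // so a nullptr listing here means "no entries" rather than an error.)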
+ TF_SetStatus(status, TF_OK, ""); + return 0; + } + TF_SetStatusFromIOError(status, errno, path); + return -1; + } + *entries = static_cast( + plugin_memory_allocate(num_entries * sizeof((*entries)[0]))); + auto BaseName = [](const std::string& name) { + return name.substr(name.find_last_of('/') + 1); + }; + for (int i = 0; i < num_entries; i++) { + (*entries)[i] = strdup(BaseName(info[i].mName).c_str()); + } + libhdfs->hdfsFreeFileInfo(info, num_entries); + TF_SetStatus(status, TF_OK, ""); + return num_entries; +} + +// TODO(vnvo2409): Implement later + +} // namespace tf_hadoop_filesystem + +static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, + const char* uri) { + TF_SetFilesystemVersionMetadata(ops); + ops->scheme = strdup(uri); +} + +void TF_InitPlugin(TF_FilesystemPluginInfo* info) { + info->plugin_memory_allocate = plugin_memory_allocate; + info->plugin_memory_free = plugin_memory_free; + info->num_schemes = 3; + info->ops = static_cast( + plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0]))); + ProvideFilesystemSupportFor(&info->ops[0], "hdfs"); + ProvideFilesystemSupportFor(&info->ops[1], "viewfs"); + ProvideFilesystemSupportFor(&info->ops[2], "har"); +} diff --git a/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h new file mode 100644 index 00000000000..850cefe0231 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/hadoop/hadoop_filesystem.h @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ + +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/tf_status.h" + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_HADOOP_HADOOP_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD new file mode 100644 index 00000000000..56bd3b4a75c --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD @@ -0,0 +1,63 @@ +# Experimental s3 filesystem plugin. +load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") + +package( + licenses = ["notice"], # Apache 2.0 +) + +# Filesystem implementation for S3 environments +tf_cc_shared_object( + name = "s3_filesystem", + framework_so = [], + linkstatic = False, + per_os_targets = 1, + visibility = ["//visibility:public"], + deps = [":s3_filesystem_impl"], +) + +# The real implementation of the filesystem. 
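+# Kept as a separate cc_library, presumably so that both the shared object
+# above and the test target below can depend on it.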
+cc_library( + name = "s3_filesystem_impl", + srcs = ["s3_filesystem.cc"], + hdrs = ["s3_filesystem.h"], + copts = select({ + "//conditions:default": [], + "//tensorflow:windows": get_win_copts(), + }), + deps = [ + ":aws_crypto", + "//tensorflow/c:tf_status", + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "@aws", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + ], +) + +cc_library( + name = "aws_crypto", + srcs = ["aws_crypto.cc"], + hdrs = ["aws_crypto.h"], + deps = [ + "@aws", + "@boringssl//:crypto", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "s3_filesystem_test", + srcs = [ + "s3_filesystem_test.cc", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":s3_filesystem_impl", + "//tensorflow/core/platform:path", + "//tensorflow/core/platform:stacktrace_handler", + "//tensorflow/core/platform:test", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc new file mode 100644 index 00000000000..2e15ac176e3 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h" + +#include +#include +#include +#include +#include + +namespace tf_s3_filesystem { + +class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC { + public: + AWSSha256HMACOpenSSLImpl() {} + + virtual ~AWSSha256HMACOpenSSLImpl() = default; + + Aws::Utils::Crypto::HashResult Calculate( + const Aws::Utils::ByteBuffer& toSign, + const Aws::Utils::ByteBuffer& secret) override { + unsigned int length = SHA256_DIGEST_LENGTH; + Aws::Utils::ByteBuffer digest(length); + memset(digest.GetUnderlyingData(), 0, length); + + HMAC_CTX ctx; + HMAC_CTX_init(&ctx); + + HMAC_Init_ex(&ctx, secret.GetUnderlyingData(), + static_cast(secret.GetLength()), EVP_sha256(), NULL); + HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength()); + HMAC_Final(&ctx, digest.GetUnderlyingData(), &length); + HMAC_CTX_cleanup(&ctx); + + return Aws::Utils::Crypto::HashResult(std::move(digest)); + } +}; + +class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash { + public: + AWSSha256OpenSSLImpl() {} + + virtual ~AWSSha256OpenSSLImpl() = default; + + Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + SHA256_Update(&sha256, str.data(), str.size()); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } + + Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override { + SHA256_CTX sha256; + SHA256_Init(&sha256); + + auto currentPos = stream.tellg(); + if (currentPos == std::streampos(std::streamoff(-1))) { + currentPos = 0; + 
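+      // tellg() failed (e.g. the stream is in an error state or is not
+      // seekable), so hash from the beginning and clear the error flags
+      // before seeking.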
stream.clear(); + } + + stream.seekg(0, stream.beg); + + char streamBuffer + [Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE]; + while (stream.good()) { + stream.read(streamBuffer, + Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE); + auto bytesRead = stream.gcount(); + + if (bytesRead > 0) { + SHA256_Update(&sha256, streamBuffer, static_cast(bytesRead)); + } + } + + stream.clear(); + stream.seekg(currentPos, stream.beg); + + Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH); + SHA256_Final(hash.GetUnderlyingData(), &sha256); + + return Aws::Utils::Crypto::HashResult(std::move(hash)); + } +}; + +class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes { + public: + AWSSecureRandomBytesImpl() {} + virtual ~AWSSecureRandomBytesImpl() = default; + void GetBytes(unsigned char* buffer, size_t bufferSize) override { + assert(buffer); + int success = RAND_bytes(buffer, static_cast(bufferSize)); + if (success != 1) { + m_failure = true; + } + } + + private: + bool m_failure; +}; + +std::shared_ptr +AWSSHA256Factory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +std::shared_ptr +AWSSHA256HmacFactory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +std::shared_ptr +AWSSecureRandomFactory::CreateImplementation() const { + return Aws::MakeShared(AWSCryptoAllocationTag); +} + +} // namespace tf_s3_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h new file mode 100644 index 00000000000..a70bf060fc7 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ + +#include +#include +#include +#include +#include + +namespace tf_s3_filesystem { +constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation"; + +class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory { + public: + std::shared_ptr CreateImplementation() + const override; +}; + +} // namespace tf_s3_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc new file mode 100644 index 00000000000..7e1b36f2dcc --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc @@ -0,0 +1,1239 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h" +#include "tensorflow/c/tf_status.h" + +// Implementation of a filesystem for S3 environments. +// This filesystem will support `s3://` URI schemes. 
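+// Paths are expected to look like `s3://<bucket>/<object>`; see ParseS3Path()
+// below for the exact parsing rules.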
+constexpr char kS3FileSystemAllocationTag[] = "S3FileSystemAllocation"; +constexpr char kS3ClientAllocationTag[] = "S3ClientAllocation"; +constexpr int64_t kS3TimeoutMsec = 300000; // 5 min +constexpr int kS3GetChildrenMaxKeys = 100; + +constexpr char kExecutorTag[] = "TransferManagerExecutorAllocation"; +constexpr int kExecutorPoolSize = 25; + +constexpr uint64_t kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB +constexpr uint64_t kS3MultiPartDownloadChunkSize = 50 * 1024 * 1024; // 50 MB +constexpr size_t kDownloadRetries = 3; +constexpr size_t kUploadRetries = 3; + +constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024; // 1 MB + +static void* plugin_memory_allocate(size_t size) { return calloc(1, size); } +static void plugin_memory_free(void* ptr) { free(ptr); } + +static inline void TF_SetStatusFromAWSError( + const Aws::Client::AWSError& error, TF_Status* status) { + switch (error.GetResponseCode()) { + case Aws::Http::HttpResponseCode::FORBIDDEN: + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "AWS Credentials have not been set properly. " + "Unable to access the specified S3 location"); + break; + case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE: + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + break; + case Aws::Http::HttpResponseCode::NOT_FOUND: + TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str()); + break; + default: + TF_SetStatus( + status, TF_UNKNOWN, + (error.GetExceptionName() + ": " + error.GetMessage()).c_str()); + break; + } +} + +void ParseS3Path(const Aws::String& fname, bool object_empty_ok, + Aws::String* bucket, Aws::String* object, TF_Status* status) { + size_t scheme_end = fname.find("://") + 2; + if (fname.substr(0, scheme_end + 1) != "s3://") { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "S3 path doesn't start with 's3://'."); + return; + } + + size_t bucket_end = fname.find("/", scheme_end + 1); + if (bucket_end == std::string::npos) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "S3 path doesn't contain a bucket name."); + return; + } + + *bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1); + *object = fname.substr(bucket_end + 1); + + if (object->empty() && !object_empty_ok) { + TF_SetStatus(status, TF_INVALID_ARGUMENT, + "S3 path doesn't contain an object name."); + } +} + +static Aws::Client::ClientConfiguration& GetDefaultClientConfig() { + ABSL_CONST_INIT static absl::Mutex cfg_lock(absl::kConstInit); + static bool init(false); + static Aws::Client::ClientConfiguration cfg; + + absl::MutexLock l(&cfg_lock); + + if (!init) { + const char* endpoint = getenv("S3_ENDPOINT"); + if (endpoint) cfg.endpointOverride = Aws::String(endpoint); + const char* region = getenv("AWS_REGION"); + // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. + if (!region) region = getenv("S3_REGION"); + if (region) { + cfg.region = Aws::String(region); + } else { + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + std::string load_config = + load_config_env ? absl::AsciiStrToLower(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. 
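+        // Only the region of the [default] profile is consulted here.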
+ const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + loader.Load(); + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) + cfg.region = profiles["default"].GetRegion(); + } + } + const char* use_https = getenv("S3_USE_HTTPS"); + if (use_https) { + if (use_https[0] == '0') + cfg.scheme = Aws::Http::Scheme::HTTP; + else + cfg.scheme = Aws::Http::Scheme::HTTPS; + } + const char* verify_ssl = getenv("S3_VERIFY_SSL"); + if (verify_ssl) { + if (verify_ssl[0] == '0') + cfg.verifySSL = false; + else + cfg.verifySSL = true; + } + // if these timeouts are low, you may see an error when + // uploading/downloading large files: Unable to connect to endpoint + int64_t timeout; + cfg.connectTimeoutMs = + absl::SimpleAtoi(getenv("S3_CONNECT_TIMEOUT_MSEC"), &timeout) + ? timeout + : kS3TimeoutMsec; + cfg.requestTimeoutMs = + absl::SimpleAtoi(getenv("S3_REQUEST_TIMEOUT_MSEC"), &timeout) + ? timeout + : kS3TimeoutMsec; + const char* ca_file = getenv("S3_CA_FILE"); + if (ca_file) cfg.caFile = Aws::String(ca_file); + const char* ca_path = getenv("S3_CA_PATH"); + if (ca_path) cfg.caPath = Aws::String(ca_path); + init = true; + } + return cfg; +}; + +static void GetS3Client(tf_s3_filesystem::S3File* s3_file) { + absl::MutexLock l(&s3_file->initialization_lock); + + if (s3_file->s3_client.get() == nullptr) { + Aws::SDKOptions options; + options.cryptoOptions.sha256Factory_create_fn = []() { + return Aws::MakeShared( + tf_s3_filesystem::AWSCryptoAllocationTag); + }; + options.cryptoOptions.sha256HMACFactory_create_fn = []() { + return Aws::MakeShared( + tf_s3_filesystem::AWSCryptoAllocationTag); + }; + options.cryptoOptions.secureRandomFactory_create_fn = []() { + return Aws::MakeShared( + tf_s3_filesystem::AWSCryptoAllocationTag); + }; + Aws::InitAPI(options); + + // The creation of S3Client disables virtual addressing: + // S3Client(clientConfiguration, signPayloads, useVirtualAddressing = + // true) + // The purpose is to address the issue encountered when there is an `.` + // in the bucket name. Due to TLS hostname validation or DNS rules, + // the bucket may not be resolved. Disabling of virtual addressing + // should address the issue. See GitHub issue 16397 for details. + s3_file->s3_client = Aws::MakeShared( + kS3ClientAllocationTag, GetDefaultClientConfig(), + Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false); + } +} + +static void GetExecutor(tf_s3_filesystem::S3File* s3_file) { + absl::MutexLock l(&s3_file->initialization_lock); + + if (s3_file->executor.get() == nullptr) { + s3_file->executor = + Aws::MakeShared( + kExecutorTag, kExecutorPoolSize); + } +} + +static void GetTransferManager( + const Aws::Transfer::TransferDirection& direction, + tf_s3_filesystem::S3File* s3_file) { + // These functions should be called before holding `initialization_lock`. 
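+  // (Each of them acquires the same `initialization_lock` internally, and
+  // absl::Mutex is not reentrant, so calling them after the MutexLock below
+  // would deadlock.)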
+ GetS3Client(s3_file); + GetExecutor(s3_file); + + absl::MutexLock l(&s3_file->initialization_lock); + + if (s3_file->transfer_managers[direction].get() == nullptr) { + Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get()); + config.s3Client = s3_file->s3_client; + config.bufferSize = s3_file->multi_part_chunk_sizes[direction]; + // must be larger than pool size * multi part chunk size + config.transferBufferMaxHeapSize = + (kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction]; + s3_file->transfer_managers[direction] = + Aws::Transfer::TransferManager::Create(config); + } +} + +static void ShutdownClient(Aws::S3::S3Client* s3_client) { + if (s3_client != nullptr) { + delete s3_client; + Aws::SDKOptions options; + Aws::ShutdownAPI(options); + } +} + +// SECTION 1. Implementation for `TF_RandomAccessFile` +// ---------------------------------------------------------------------------- +namespace tf_random_access_file { +typedef struct S3File { + Aws::String bucket; + Aws::String object; + std::shared_ptr s3_client; + std::shared_ptr transfer_manager; + bool use_multi_part_download; +} S3File; + +// AWS Streams destroy the buffer (buf) passed, so creating a new +// IOStream that retains the buffer so the calling function +// can control it's lifecycle +class TFS3UnderlyingStream : public Aws::IOStream { + public: + using Base = Aws::IOStream; + TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {} + virtual ~TFS3UnderlyingStream() = default; +}; + +void Cleanup(TF_RandomAccessFile* file) { + auto s3_file = static_cast(file->plugin_file); + delete s3_file; +} + +static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, + char* buffer, TF_Status* status) { + Aws::S3::Model::GetObjectRequest get_object_request; + get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object); + Aws::String bytes = + absl::StrCat("bytes=", offset, "-", offset + n - 1).c_str(); + get_object_request.SetRange(bytes); + get_object_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + + auto get_object_outcome = s3_file->s3_client->GetObject(get_object_request); + if (!get_object_outcome.IsSuccess()) + TF_SetStatusFromAWSError(get_object_outcome.GetError(), status); + else + TF_SetStatus(status, TF_OK, ""); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE) + return -1; + + int64_t read = get_object_outcome.GetResult().GetContentLength(); + if (read < n) + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + get_object_outcome.GetResult().GetBody().read(buffer, read); + return read; +} + +static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, + char* buffer, TF_Status* status) { + auto create_download_stream = [&]() { + return Aws::New( + "S3ReadStream", + Aws::New( + "S3ReadStream", reinterpret_cast(buffer), n)); + }; + auto handle = s3_file->transfer_manager->DownloadFile( + s3_file->bucket, s3_file->object, offset, n, create_download_stream); + handle->WaitUntilFinished(); + + size_t retries = 0; + while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED && + handle->GetLastError().GetResponseCode() != + Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE && + retries++ < kDownloadRetries) { + // Only failed parts will be downloaded again. 
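+    // Range-not-satisfiable responses are excluded from the retry condition
+    // above: they indicate the offset is at or past EOF and are surfaced as
+    // TF_OUT_OF_RANGE below instead.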
+ s3_file->transfer_manager->RetryDownload(handle); + handle->WaitUntilFinished(); + } + + if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED) + TF_SetStatusFromAWSError(handle->GetLastError(), status); + else + TF_SetStatus(status, TF_OK, ""); + if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE) + return -1; + int64_t read = handle->GetBytesTransferred(); + if (read < n) + TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested"); + return read; +} + +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status) { + auto s3_file = static_cast(file->plugin_file); + if (s3_file->use_multi_part_download) + return ReadS3TransferManager(s3_file, offset, n, buffer, status); + else + return ReadS3Client(s3_file, offset, n, buffer, status); +} + +} // namespace tf_random_access_file + +// SECTION 2. Implementation for `TF_WritableFile` +// ---------------------------------------------------------------------------- +namespace tf_writable_file { +typedef struct S3File { + Aws::String bucket; + Aws::String object; + std::shared_ptr s3_client; + std::shared_ptr transfer_manager; + bool sync_needed; + std::shared_ptr outfile; + S3File(Aws::String bucket, Aws::String object, + std::shared_ptr s3_client, + std::shared_ptr transfer_manager) + : bucket(bucket), + object(object), + s3_client(s3_client), + transfer_manager(transfer_manager), + outfile(Aws::MakeShared( + kS3FileSystemAllocationTag, nullptr, "_s3_filesystem_XXXXXX", + std::ios_base::binary | std::ios_base::trunc | std::ios_base::in | + std::ios_base::out)) {} +} S3File; + +void Cleanup(TF_WritableFile* file) { + auto s3_file = static_cast(file->plugin_file); + delete s3_file; +} + +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status) { + auto s3_file = static_cast(file->plugin_file); + if (!s3_file->outfile) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "The internal temporary file is not writable."); + return; + } + s3_file->sync_needed = true; + s3_file->outfile->write(buffer, n); + if (!s3_file->outfile->good()) + TF_SetStatus(status, TF_INTERNAL, + "Could not append to the internal temporary file."); + else + TF_SetStatus(status, TF_OK, ""); +} + +int64_t Tell(const TF_WritableFile* file, TF_Status* status) { + auto s3_file = static_cast(file->plugin_file); + auto position = static_cast(s3_file->outfile->tellp()); + if (position == -1) + TF_SetStatus(status, TF_INTERNAL, + "tellp on the internal temporary file failed"); + else + TF_SetStatus(status, TF_OK, ""); + return position; +} + +void Sync(const TF_WritableFile* file, TF_Status* status) { + auto s3_file = static_cast(file->plugin_file); + if (!s3_file->outfile) { + TF_SetStatus(status, TF_FAILED_PRECONDITION, + "The internal temporary file is not writable."); + return; + } + if (!s3_file->sync_needed) { + TF_SetStatus(status, TF_OK, ""); + return; + } + auto position = static_cast(s3_file->outfile->tellp()); + auto handle = s3_file->transfer_manager->UploadFile( + s3_file->outfile, s3_file->bucket, s3_file->object, + "application/octet-stream", Aws::Map()); + handle->WaitUntilFinished(); + + size_t retries = 0; + while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED && + retries++ < kUploadRetries) { + // if multipart upload was used, only the failed parts will be re-sent + s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle); + handle->WaitUntilFinished(); + } + if (handle->GetStatus() != 
Aws::Transfer::TransferStatus::COMPLETED) + return TF_SetStatusFromAWSError(handle->GetLastError(), status); + s3_file->outfile->clear(); + s3_file->outfile->seekp(position); + s3_file->sync_needed = false; + TF_SetStatus(status, TF_OK, ""); +} + +void Flush(const TF_WritableFile* file, TF_Status* status) { + Sync(file, status); +} + +void Close(const TF_WritableFile* file, TF_Status* status) { + auto s3_file = static_cast(file->plugin_file); + if (s3_file->outfile) { + Sync(file, status); + if (TF_GetCode(status) != TF_OK) return; + s3_file->outfile.reset(); + } + TF_SetStatus(status, TF_OK, ""); +} + +} // namespace tf_writable_file + +// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion` +// ---------------------------------------------------------------------------- +namespace tf_read_only_memory_region { +typedef struct S3MemoryRegion { + std::unique_ptr data; + uint64_t length; +} S3MemoryRegion; + +void Cleanup(TF_ReadOnlyMemoryRegion* region) { + auto r = static_cast(region->plugin_memory_region); + delete r; +} + +const void* Data(const TF_ReadOnlyMemoryRegion* region) { + auto r = static_cast(region->plugin_memory_region); + return reinterpret_cast(r->data.get()); +} + +uint64_t Length(const TF_ReadOnlyMemoryRegion* region) { + auto r = static_cast(region->plugin_memory_region); + return r->length; +} + +} // namespace tf_read_only_memory_region + +// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem +// ---------------------------------------------------------------------------- +namespace tf_s3_filesystem { +S3File::S3File() + : s3_client(nullptr, ShutdownClient), + executor(nullptr), + transfer_managers(), + multi_part_chunk_sizes(), + use_multi_part_download(true), + initialization_lock() { + uint64_t temp_value; + multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] = + absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value) + ? temp_value + : kS3MultiPartUploadChunkSize; + multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] = + absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value) + ? temp_value + : kS3MultiPartDownloadChunkSize; + use_multi_part_download = + absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value) + ? 
(temp_value != 1) + : use_multi_part_download; + transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr); + transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD, + nullptr); +} +void Init(TF_Filesystem* filesystem, TF_Status* status) { + filesystem->plugin_filesystem = new S3File(); + TF_SetStatus(status, TF_OK, ""); +} + +void Cleanup(TF_Filesystem* filesystem) { + auto s3_file = static_cast(filesystem->plugin_filesystem); + delete s3_file; +} + +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD, s3_file); + file->plugin_file = new tf_random_access_file::S3File( + {bucket, object, s3_file->s3_client, + s3_file->transfer_managers[Aws::Transfer::TransferDirection::DOWNLOAD], + s3_file->use_multi_part_download}); + TF_SetStatus(status, TF_OK, ""); +} + +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file); + file->plugin_file = new tf_writable_file::S3File( + bucket, object, s3_file->s3_client, + s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]); + TF_SetStatus(status, TF_OK, ""); +} + +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file); + + // We need to delete `file->plugin_file` in case of errors. We set + // `file->plugin_file` to `nullptr` in order to avoid segment fault when + // calling deleter of `unique_ptr`. + file->plugin_file = nullptr; + std::unique_ptr writer( + file, [](TF_WritableFile* file) { + if (file != nullptr && file->plugin_file != nullptr) { + tf_writable_file::Cleanup(file); + } + }); + writer->plugin_file = new tf_writable_file::S3File( + bucket, object, s3_file->s3_client, + s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]); + TF_SetStatus(status, TF_OK, ""); + + // Wraping inside a `std::unique_ptr` to prevent memory-leaking. 
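+  // The reader created here is used to copy any existing object contents into
+  // the writable file's local temporary file, since S3 has no native append
+  // operation; later Appends then continue from the existing contents.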
+ std::unique_ptr reader( + new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + // We set `reader->plugin_file` to `nullptr` in order to avoid segment fault + // when calling deleter of `unique_ptr` + reader->plugin_file = nullptr; + NewRandomAccessFile(filesystem, path, reader.get(), status); + if (TF_GetCode(status) != TF_OK) return; + + uint64_t offset = 0; + std::string buffer(kS3ReadAppendableFileBufferSize, {}); + while (true) { + auto read = tf_random_access_file::Read(reader.get(), offset, + kS3ReadAppendableFileBufferSize, + &buffer[0], status); + if (TF_GetCode(status) == TF_NOT_FOUND) { + break; + } else if (TF_GetCode(status) == TF_OK) { + offset += read; + tf_writable_file::Append(file, buffer.c_str(), read, status); + if (TF_GetCode(status) != TF_OK) return; + } else if (TF_GetCode(status) == TF_OUT_OF_RANGE) { + offset += read; + tf_writable_file::Append(file, buffer.c_str(), read, status); + if (TF_GetCode(status) != TF_OK) return; + break; + } else { + return; + } + } + writer.release(); + TF_SetStatus(status, TF_OK, ""); +} + +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + if (object.empty()) { + Aws::S3::Model::HeadBucketRequest head_bucket_request; + head_bucket_request.WithBucket(bucket); + auto head_bucket_outcome = + s3_file->s3_client->HeadBucket(head_bucket_request); + if (!head_bucket_outcome.IsSuccess()) + return TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status); + stats->length = 0; + stats->is_directory = 1; + stats->mtime_nsec = 0; + return TF_SetStatus(status, TF_OK, ""); + } + + bool found = false; + Aws::S3::Model::HeadObjectRequest head_object_request; + head_object_request.WithBucket(bucket).WithKey(object); + head_object_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + auto head_object_outcome = + s3_file->s3_client->HeadObject(head_object_request); + if (head_object_outcome.IsSuccess()) { + stats->length = head_object_outcome.GetResult().GetContentLength(); + stats->is_directory = 0; + stats->mtime_nsec = + head_object_outcome.GetResult().GetLastModified().Millis() * 1e6; + found = true; + } else { + TF_SetStatusFromAWSError(head_object_outcome.GetError(), status); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) return; + } + + auto prefix = object; + if (prefix.back() != '/') { + prefix.push_back('/'); + } + Aws::S3::Model::ListObjectsRequest list_objects_request; + list_objects_request.WithBucket(bucket).WithPrefix(prefix).WithMaxKeys(1); + list_objects_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + auto list_objects_outcome = + s3_file->s3_client->ListObjects(list_objects_request); + if (list_objects_outcome.IsSuccess()) { + auto objects = list_objects_outcome.GetResult().GetContents(); + if (objects.size() > 0) { + stats->length = 0; + stats->is_directory = 1; + stats->mtime_nsec = objects[0].GetLastModified().Millis() * 1e6; + found = true; + } + } else { + TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status); + if (TF_GetCode(status) == TF_FAILED_PRECONDITION) return; + } + if (!found) + return TF_SetStatus( + status, 
TF_NOT_FOUND, + absl::StrCat("Object ", path, " does not exist").c_str()); + TF_SetStatus(status, TF_OK, ""); +} + +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_FileStatistics stats; + Stat(filesystem, path, &stats, status); +} + +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + TF_FileStatistics stats; + Stat(filesystem, path, &stats, status); + return stats.length; +} + +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file); + + auto size = GetFileSize(filesystem, path, status); + if (TF_GetCode(status) != TF_OK) return; + if (size == 0) + return TF_SetStatus(status, TF_INVALID_ARGUMENT, "File is empty"); + + std::unique_ptr data(new char[size]); + // Wraping inside a `std::unique_ptr` to prevent memory-leaking. + std::unique_ptr reader( + new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + // We set `reader->plugin_file` to `nullptr` in order to avoid segment fault + // when calling deleter of `unique_ptr` + reader->plugin_file = nullptr; + NewRandomAccessFile(filesystem, path, reader.get(), status); + if (TF_GetCode(status) != TF_OK) return; + auto read = + tf_random_access_file::Read(reader.get(), 0, size, data.get(), status); + if (TF_GetCode(status) != TF_OK) return; + + region->plugin_memory_region = new tf_read_only_memory_region::S3MemoryRegion( + {std::move(data), static_cast(read)}); + TF_SetStatus(status, TF_OK, ""); +} + +static void SimpleCopyFile(const Aws::String& source, + const Aws::String& bucket_dst, + const Aws::String& object_dst, S3File* s3_file, + TF_Status* status) { + Aws::S3::Model::CopyObjectRequest copy_object_request; + copy_object_request.WithCopySource(source) + .WithBucket(bucket_dst) + .WithKey(object_dst); + auto copy_object_outcome = + s3_file->s3_client->CopyObject(copy_object_request); + if (!copy_object_outcome.IsSuccess()) + TF_SetStatusFromAWSError(copy_object_outcome.GetError(), status); + else + TF_SetStatus(status, TF_OK, ""); +}; + +using EtagOutcome = + Aws::Utils::Outcome>; +typedef struct MultipartCopyAsyncContext + : public Aws::Client::AsyncCallerContext { + int part_number; + int* num_finished_parts; + Aws::Vector* etag_outcomes; + + // lock and cv for multi part copy + absl::Mutex* multi_part_copy_mutex; + absl::CondVar* multi_part_copy_cv; +} MultipartCopyAsyncContext; + +static void AbortMultiPartCopy(const Aws::String& bucket_dst, + const Aws::String& object_dst, + const Aws::String& upload_id, S3File* s3_file, + TF_Status* status) { + Aws::S3::Model::AbortMultipartUploadRequest request; + request.WithBucket(bucket_dst).WithKey(object_dst).WithUploadId(upload_id); + auto outcome = s3_file->s3_client->AbortMultipartUpload(request); + if (!outcome.IsSuccess()) + TF_SetStatusFromAWSError(outcome.GetError(), status); + else + TF_SetStatus(status, TF_OK, ""); +} + +static void MultiPartCopyCallback( + const Aws::S3::Model::UploadPartCopyRequest& request, + const Aws::S3::Model::UploadPartCopyOutcome& outcome, + const std::shared_ptr& 
context) { + // Access to `etag_outcomes` should be thread-safe because of distinct + // `part_number`. + auto part_number = context->part_number; + auto etag_outcomes = context->etag_outcomes; + if (outcome.IsSuccess()) { + (*etag_outcomes)[part_number] = + outcome.GetResult().GetCopyPartResult().GetETag(); + } else { + (*etag_outcomes)[part_number] = outcome.GetError(); + } + { + absl::MutexLock l(context->multi_part_copy_mutex); + (*context->num_finished_parts)++; + context->multi_part_copy_cv->Signal(); + } +} + +static void MultiPartCopy(const Aws::String& source, + const Aws::String& bucket_dst, + const Aws::String& object_dst, const size_t num_parts, + const uint64_t file_size, S3File* s3_file, + TF_Status* status) { + Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request; + create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst); + + GetS3Client(s3_file); + GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file); + + auto create_multipart_upload_outcome = + s3_file->s3_client->CreateMultipartUpload( + create_multipart_upload_request); + if (!create_multipart_upload_outcome.IsSuccess()) + return TF_SetStatusFromAWSError(create_multipart_upload_outcome.GetError(), + status); + + auto upload_id = create_multipart_upload_outcome.GetResult().GetUploadId(); + + int num_finished_parts = 0; + // Keep track of `Outcome` of each upload part. + Aws::Vector etag_outcomes(num_parts); + // Mutex which protects access of the part_states map. + absl::Mutex multi_part_copy_mutex; + // Condition variable to be used with above mutex for synchronization. + absl::CondVar multi_part_copy_cv; + + auto chunk_size = + s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD]; + + size_t retries = 0; + while (retries++ < 3) { + // Queue up parts. + for (auto part_number = 0; part_number < num_parts; ++part_number) { + if (etag_outcomes[part_number].IsSuccess()) continue; + uint64_t start_pos = part_number * chunk_size; + uint64_t end_pos = start_pos + chunk_size - 1; + if (end_pos >= file_size) end_pos = file_size - 1; + + Aws::String range = + absl::StrCat("bytes=", start_pos, "-", end_pos).c_str(); + Aws::S3::Model::UploadPartCopyRequest upload_part_copy_request; + upload_part_copy_request.WithBucket(bucket_dst) + .WithKey(object_dst) + .WithCopySource(source) + .WithCopySourceRange(range) + // S3 API partNumber starts from 1. + .WithPartNumber(part_number + 1) + .WithUploadId(upload_id); + + auto multi_part_context = + Aws::MakeShared("MultiPartCopyContext"); + multi_part_context->part_number = part_number; + multi_part_context->num_finished_parts = &num_finished_parts; + multi_part_context->etag_outcomes = &etag_outcomes; + multi_part_context->multi_part_copy_mutex = &multi_part_copy_mutex; + multi_part_context->multi_part_copy_cv = &multi_part_copy_cv; + auto callback = + [](const Aws::S3::S3Client* client, + const Aws::S3::Model::UploadPartCopyRequest& request, + const Aws::S3::Model::UploadPartCopyOutcome& outcome, + const std::shared_ptr& + context) { + auto multipart_context = + std::static_pointer_cast( + context); + MultiPartCopyCallback(request, outcome, multipart_context); + }; + + std::shared_ptr context = + multi_part_context; + s3_file->s3_client->UploadPartCopyAsync(upload_part_copy_request, + callback, context); + } + // Wait till they finish. + { + absl::MutexLock l(&multi_part_copy_mutex); + // Wait on the mutex until notify is called then check the finished parts + // as there could be false notifications. 
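+      // (num_finished_parts is only modified while holding the same mutex,
+      // so re-checking the predicate here is race-free.)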
+ while (num_finished_parts != num_parts) { + multi_part_copy_cv.Wait(&multi_part_copy_mutex); + } + } + // check if there was any error for any part. + for (auto part_number = 0; part_number < num_parts; ++part_number) { + if (!etag_outcomes[part_number].IsSuccess()) { + if (retries >= 3) { + AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, + status); + if (TF_GetCode(status) != TF_OK) return; + return TF_SetStatusFromAWSError(etag_outcomes[part_number].GetError(), + status); + } else { + // Retry. + num_finished_parts--; + } + } + } + } + + Aws::S3::Model::CompletedMultipartUpload completed_multipart_upload; + // If there was an error still in any part, it would abort and return in the + // above loop. We set the eTag of completed parts to the final + // `completed_multipart_upload`. Note these parts have to be added in order. + for (int part_number = 0; part_number < num_parts; ++part_number) { + Aws::S3::Model::CompletedPart completed_part; + completed_part.SetPartNumber(part_number + 1); + completed_part.SetETag(etag_outcomes[part_number].GetResult()); + completed_multipart_upload.AddParts(completed_part); + } + + Aws::S3::Model::CompleteMultipartUploadRequest + complete_multipart_upload_request; + complete_multipart_upload_request.WithBucket(bucket_dst) + .WithKey(object_dst) + .WithUploadId(upload_id) + .WithMultipartUpload(completed_multipart_upload); + auto complete_multipart_upload_outcome = + s3_file->s3_client->CompleteMultipartUpload( + complete_multipart_upload_request); + if (!complete_multipart_upload_outcome.IsSuccess()) + AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, status); + else + return TF_SetStatus(status, TF_OK, ""); + if (TF_GetCode(status) == TF_OK) + return TF_SetStatusFromAWSError( + complete_multipart_upload_outcome.GetError(), status); +}; + +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status) { + auto file_size = GetFileSize(filesystem, src, status); + if (TF_GetCode(status) != TF_OK) return; + if (file_size == 0) + return TF_SetStatus(status, TF_FAILED_PRECONDITION, + "Source is a directory or empty file"); + + Aws::String bucket_src, object_src; + ParseS3Path(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + Aws::String copy_src = bucket_src + "/" + object_src; + + Aws::String bucket_dst, object_dst; + ParseS3Path(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + auto chunk_size = + s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD]; + size_t num_parts = 1; + if (file_size > chunk_size) num_parts = ceil((float)file_size / chunk_size); + if (num_parts == 1) + SimpleCopyFile(copy_src, bucket_dst, object_dst, s3_file, status); + else if (num_parts > 10000) + TF_SetStatus( + status, TF_UNIMPLEMENTED, + absl::StrCat("MultiPartCopy with number of parts more than 10000 is " + "not supported. Your object ", + src, " required ", num_parts, + " as multi_part_copy_part_size is set to ", chunk_size, + ". 
You can control this part size using the environment " + "variable S3_MULTI_PART_COPY_PART_SIZE to increase it.") + .c_str()); + else + MultiPartCopy(copy_src, bucket_dst, object_dst, num_parts, file_size, + s3_file, status); +} + +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + Aws::S3::Model::DeleteObjectRequest delete_object_request; + delete_object_request.WithBucket(bucket).WithKey(object); + auto delete_object_outcome = + s3_file->s3_client->DeleteObject(delete_object_request); + if (!delete_object_outcome.IsSuccess()) + TF_SetStatusFromAWSError(delete_object_outcome.GetError(), status); + else + TF_SetStatus(status, TF_OK, ""); +} + +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + if (object.empty()) { + Aws::S3::Model::HeadBucketRequest head_bucket_request; + head_bucket_request.WithBucket(bucket); + auto head_bucket_outcome = + s3_file->s3_client->HeadBucket(head_bucket_request); + if (!head_bucket_outcome.IsSuccess()) + TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status); + else + TF_SetStatus(status, TF_OK, ""); + return; + } + + Aws::String dir_path = path; + if (dir_path.back() != '/') dir_path.push_back('/'); + + PathExists(filesystem, dir_path.c_str(), status); + if (TF_GetCode(status) == TF_OK) { + std::unique_ptr file( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + file->plugin_file = nullptr; + NewWritableFile(filesystem, dir_path.c_str(), file.get(), status); + if (TF_GetCode(status) != TF_OK) return; + tf_writable_file::Close(file.get(), status); + if (TF_GetCode(status) != TF_OK) return; + } + TF_SetStatus(status, TF_OK, ""); +} + +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status) { + Aws::String bucket, object; + ParseS3Path(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + if (object.back() != '/') object.push_back('/'); + Aws::S3::Model::ListObjectsRequest list_objects_request; + list_objects_request.WithBucket(bucket).WithPrefix(object).WithMaxKeys(2); + list_objects_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + auto list_objects_outcome = + s3_file->s3_client->ListObjects(list_objects_request); + if (list_objects_outcome.IsSuccess()) { + auto contents = list_objects_outcome.GetResult().GetContents(); + if (contents.size() > 1 || + (contents.size() == 1 && contents[0].GetKey() != object)) { + TF_SetStatus(status, TF_UNKNOWN, + "Cannot delete a non-empty directory. 
" + "This operation will be retried in case this " + "is due to S3's eventual consistency."); + } + if (contents.size() == 1 && contents[0].GetKey() == object) { + Aws::String dir_path = path; + if (dir_path.back() != '/') dir_path.push_back('/'); + DeleteFile(filesystem, dir_path.c_str(), status); + } + } else { + TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status); + } +} + +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status) { + Aws::String bucket_src, object_src; + ParseS3Path(src, false, &bucket_src, &object_src, status); + if (TF_GetCode(status) != TF_OK) return; + Aws::String copy_src = bucket_src + "/" + object_src; + + Aws::String bucket_dst, object_dst; + ParseS3Path(dst, false, &bucket_dst, &object_dst, status); + if (TF_GetCode(status) != TF_OK) return; + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + if (object_src.back() == '/') { + if (object_dst.back() != '/') { + object_dst.push_back('/'); + } + } else { + if (object_dst.back() == '/') { + object_dst.pop_back(); + } + } + + Aws::S3::Model::DeleteObjectRequest delete_object_request; + Aws::S3::Model::ListObjectsRequest list_objects_request; + list_objects_request.WithBucket(bucket_src) + .WithPrefix(object_src) + .WithMaxKeys(kS3GetChildrenMaxKeys); + list_objects_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + + Aws::S3::Model::ListObjectsResult list_objects_result; + do { + auto list_objects_outcome = + s3_file->s3_client->ListObjects(list_objects_request); + if (!list_objects_outcome.IsSuccess()) + return TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status); + + list_objects_result = list_objects_outcome.GetResult(); + for (const auto& object : list_objects_result.GetContents()) { + Aws::String key_src = object.GetKey(); + Aws::String key_dst = key_src; + key_dst.replace(0, object_src.length(), object_dst); + CopyFile(filesystem, ("s3://" + bucket_src + "/" + key_src).c_str(), + ("s3://" + bucket_dst + "/" + key_dst).c_str(), status); + if (TF_GetCode(status) != TF_OK) return; + + delete_object_request.WithBucket(bucket_src).WithKey(key_src); + auto delete_object_outcome = + s3_file->s3_client->DeleteObject(delete_object_request); + if (!delete_object_outcome.IsSuccess()) + return TF_SetStatusFromAWSError(delete_object_outcome.GetError(), + status); + } + list_objects_request.SetMarker(list_objects_result.GetNextMarker()); + } while (list_objects_result.GetIsTruncated()); + TF_SetStatus(status, TF_OK, ""); +} + +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status) { + Aws::String bucket, prefix; + ParseS3Path(path, true, &bucket, &prefix, status); + if (TF_GetCode(status) != TF_OK) return -1; + if (!prefix.empty() && prefix.back() != '/') prefix.push_back('/'); + + auto s3_file = static_cast(filesystem->plugin_filesystem); + GetS3Client(s3_file); + + Aws::S3::Model::ListObjectsRequest list_objects_request; + list_objects_request.WithBucket(bucket) + .WithPrefix(prefix) + .WithMaxKeys(kS3GetChildrenMaxKeys) + .WithDelimiter("/"); + list_objects_request.SetResponseStreamFactory( + []() { return Aws::New(kS3FileSystemAllocationTag); }); + + Aws::S3::Model::ListObjectsResult list_objects_result; + std::vector result; + do { + auto list_objects_outcome = + s3_file->s3_client->ListObjects(list_objects_request); + if (!list_objects_outcome.IsSuccess()) { + TF_SetStatusFromAWSError(list_objects_outcome.GetError(), 
status); + return -1; + } + + list_objects_result = list_objects_outcome.GetResult(); + for (const auto& object : list_objects_result.GetCommonPrefixes()) { + Aws::String s = object.GetPrefix(); + s.erase(s.length() - 1); + Aws::String entry = s.substr(prefix.length()); + if (entry.length() > 0) { + result.push_back(entry); + } + } + for (const auto& object : list_objects_result.GetContents()) { + Aws::String s = object.GetKey(); + Aws::String entry = s.substr(prefix.length()); + if (entry.length() > 0) { + result.push_back(entry); + } + } + list_objects_result.SetMarker(list_objects_result.GetNextMarker()); + } while (list_objects_result.GetIsTruncated()); + + int num_entries = result.size(); + *entries = static_cast( + plugin_memory_allocate(num_entries * sizeof((*entries)[0]))); + for (int i = 0; i < num_entries; i++) + (*entries)[i] = strdup(result[i].c_str()); + TF_SetStatus(status, TF_OK, ""); + return num_entries; +} + +static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) { + return strdup(uri); +} + +} // namespace tf_s3_filesystem + +static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, + const char* uri) { + TF_SetFilesystemVersionMetadata(ops); + ops->scheme = strdup(uri); + + ops->random_access_file_ops = static_cast( + plugin_memory_allocate(TF_RANDOM_ACCESS_FILE_OPS_SIZE)); + ops->random_access_file_ops->cleanup = tf_random_access_file::Cleanup; + ops->random_access_file_ops->read = tf_random_access_file::Read; + + ops->writable_file_ops = static_cast( + plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); + ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->writable_file_ops->append = tf_writable_file::Append; + ops->writable_file_ops->tell = tf_writable_file::Tell; + ops->writable_file_ops->flush = tf_writable_file::Flush; + ops->writable_file_ops->sync = tf_writable_file::Sync; + ops->writable_file_ops->close = tf_writable_file::Close; + + ops->read_only_memory_region_ops = static_cast( + plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE)); + ops->read_only_memory_region_ops->cleanup = + tf_read_only_memory_region::Cleanup; + ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data; + ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length; + + ops->filesystem_ops = static_cast( + plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); + ops->filesystem_ops->init = tf_s3_filesystem::Init; + ops->filesystem_ops->cleanup = tf_s3_filesystem::Cleanup; + ops->filesystem_ops->new_random_access_file = + tf_s3_filesystem::NewRandomAccessFile; + ops->filesystem_ops->new_writable_file = tf_s3_filesystem::NewWritableFile; + ops->filesystem_ops->new_appendable_file = + tf_s3_filesystem::NewAppendableFile; + ops->filesystem_ops->new_read_only_memory_region_from_file = + tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile; + ops->filesystem_ops->create_dir = tf_s3_filesystem::CreateDir; + ops->filesystem_ops->delete_file = tf_s3_filesystem::DeleteFile; + ops->filesystem_ops->delete_dir = tf_s3_filesystem::DeleteDir; + ops->filesystem_ops->copy_file = tf_s3_filesystem::CopyFile; + ops->filesystem_ops->rename_file = tf_s3_filesystem::RenameFile; + ops->filesystem_ops->path_exists = tf_s3_filesystem::PathExists; + ops->filesystem_ops->get_file_size = tf_s3_filesystem::GetFileSize; + ops->filesystem_ops->stat = tf_s3_filesystem::Stat; + ops->filesystem_ops->get_children = tf_s3_filesystem::GetChildren; + ops->filesystem_ops->translate_name = tf_s3_filesystem::TranslateName; +} + +void 
TF_InitPlugin(TF_FilesystemPluginInfo* info) { + info->plugin_memory_allocate = plugin_memory_allocate; + info->plugin_memory_free = plugin_memory_free; + info->num_schemes = 1; + info->ops = static_cast( + plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0]))); + ProvideFilesystemSupportFor(&info->ops[0], "s3"); +} diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h new file mode 100644 index 00000000000..4a995e8c109 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h @@ -0,0 +1,101 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/tf_status.h" + +void ParseS3Path(const Aws::String& fname, bool object_empty_ok, + Aws::String* bucket, Aws::String* object, TF_Status* status); + +namespace tf_random_access_file { +void Cleanup(TF_RandomAccessFile* file); +int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, + char* buffer, TF_Status* status); +} // namespace tf_random_access_file + +namespace tf_writable_file { +void Cleanup(TF_WritableFile* file); +void Append(const TF_WritableFile* file, const char* buffer, size_t n, + TF_Status* status); +int64_t Tell(const TF_WritableFile* file, TF_Status* status); +void Sync(const TF_WritableFile* file, TF_Status* status); +void Flush(const TF_WritableFile* file, TF_Status* status); +void Close(const TF_WritableFile* file, TF_Status* status); +} // namespace tf_writable_file + +namespace tf_read_only_memory_region { +void Cleanup(TF_ReadOnlyMemoryRegion* region); +const void* Data(const TF_ReadOnlyMemoryRegion* region); +uint64_t Length(const TF_ReadOnlyMemoryRegion* region); +} // namespace tf_read_only_memory_region + +namespace tf_s3_filesystem { +typedef struct S3File { + std::shared_ptr s3_client; + std::shared_ptr executor; + // We need 2 `TransferManager`, for multipart upload/download. + Aws::Map> + transfer_managers; + // Sizes to split objects during multipart upload/download. 
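+  // (Keyed by `Aws::Transfer::TransferDirection`; `CopyFile` reads the UPLOAD
+  // entry to choose between a simple and a multipart copy, and tests override
+  // the DOWNLOAD entry to exercise chunked reads.)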
+ Aws::Map multi_part_chunk_sizes; + bool use_multi_part_download; + absl::Mutex initialization_lock; + S3File(); +} S3File; + +void Init(TF_Filesystem* filesystem, TF_Status* status); +void Cleanup(TF_Filesystem* filesystem); +void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path, + TF_RandomAccessFile* file, TF_Status* status); +void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status); +int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem, + const char* path, + TF_ReadOnlyMemoryRegion* region, + TF_Status* status); +void PathExists(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CreateDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +int GetChildren(const TF_Filesystem* filesystem, const char* path, + char*** entries, TF_Status* status); +void DeleteFile(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void Stat(const TF_Filesystem* filesystem, const char* path, + TF_FileStatistics* stats, TF_Status* status); +void DeleteDir(const TF_Filesystem* filesystem, const char* path, + TF_Status* status); +void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, + TF_Status* status); +void RenameFile(const TF_Filesystem* filesystem, const char* src, + const char* dst, TF_Status* status); +} // namespace tf_s3_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem_test.cc new file mode 100644 index 00000000000..7dc80fb11ed --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem_test.cc @@ -0,0 +1,540 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h" + +#include +#include + +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stacktrace_handler.h" +#include "tensorflow/core/platform/test.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) +#define EXPECT_TF_OK(x) EXPECT_EQ(TF_OK, TF_GetCode(x)) << TF_Message(x) + +static std::string InitializeTmpDir() { + // This env should be something like `s3://bucket/path` + const char* test_dir = getenv("S3_TEST_TMPDIR"); + if (test_dir != nullptr) { + Aws::String bucket, object; + TF_Status* status = TF_NewStatus(); + ParseS3Path(test_dir, true, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) { + TF_DeleteStatus(status); + return ""; + } + TF_DeleteStatus(status); + + // We add a random value into `test_dir` to ensures that two consecutive + // runs are unlikely to clash. + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> distribution; + std::string rng_val = std::to_string(distribution(gen)); + return tensorflow::io::JoinPath(std::string(test_dir), rng_val); + } else { + return ""; + } +} + +static std::string GetLocalLargeFile() { + // This env is used when we want to test against a large file ( ~ 50MB ). + // `S3_TEST_LOCAL_LARGE_FILE` and `S3_TEST_SERVER_LARGE_FILE` must be the same + // file. + static std::string path; + if (path.empty()) { + const char* env = getenv("S3_TEST_LOCAL_LARGE_FILE"); + if (env == nullptr) return ""; + path = env; + } + return path; +} + +static std::string GetServerLargeFile() { + // This env is used when we want to test against a large file ( ~ 50MB ). + // `S3_TEST_LOCAL_LARGE_FILE` and `S3_TEST_SERVER_LARGE_FILE` must be the same + // file. + static std::string path; + if (path.empty()) { + const char* env = getenv("S3_TEST_SERVER_LARGE_FILE"); + if (env == nullptr) return ""; + Aws::String bucket, object; + TF_Status* status = TF_NewStatus(); + ParseS3Path(env, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) { + TF_DeleteStatus(status); + return ""; + } + TF_DeleteStatus(status); + path = env; + } + return path; +} + +static std::string* GetTmpDir() { + static std::string tmp_dir = InitializeTmpDir(); + if (tmp_dir == "") + return nullptr; + else + return &tmp_dir; +} + +namespace tensorflow { +namespace { + +class S3FilesystemTest : public ::testing::Test { + public: + void SetUp() override { + root_dir_ = io::JoinPath( + *GetTmpDir(), + ::testing::UnitTest::GetInstance()->current_test_info()->name()); + status_ = TF_NewStatus(); + filesystem_ = new TF_Filesystem; + tf_s3_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Could not initialize filesystem. 
" + << TF_Message(status_); + } + void TearDown() override { + TF_DeleteStatus(status_); + tf_s3_filesystem::Cleanup(filesystem_); + delete filesystem_; + } + + std::string GetURIForPath(const std::string& path) { + const std::string translated_name = + tensorflow::io::JoinPath(root_dir_, path); + return translated_name; + } + + std::unique_ptr + GetWriter() { + std::unique_ptr writer( + new TF_WritableFile, [](TF_WritableFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) tf_writable_file::Cleanup(file); + delete file; + } + }); + writer->plugin_file = nullptr; + return writer; + } + + std::unique_ptr + GetReader() { + std::unique_ptr + reader(new TF_RandomAccessFile, [](TF_RandomAccessFile* file) { + if (file != nullptr) { + if (file->plugin_file != nullptr) + tf_random_access_file::Cleanup(file); + delete file; + } + }); + reader->plugin_file = nullptr; + return reader; + } + + void WriteString(const std::string& path, const std::string& content) { + auto writer = GetWriter(); + tf_s3_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Append(writer.get(), content.c_str(), content.length(), + status_); + if (TF_GetCode(status_) != TF_OK) return; + tf_writable_file::Close(writer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return; + } + + std::string ReadAll(const string& path) { + auto reader = GetReader(); + tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + auto file_size = + tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::string content; + content.resize(file_size); + auto read = tf_random_access_file::Read(reader.get(), 0, file_size, + &content[0], status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read >= 0) content.resize(read); + if (file_size != content.size()) + TF_SetStatus( + status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + " got " + + std::to_string(content.size()) + " bytes") + .c_str()); + return content; + } + + std::string ReadAllInChunks(const string& path, size_t buffer_size, + bool use_multi_part_download) { + auto reader = GetReader(); + auto s3_file = + static_cast(filesystem_->plugin_filesystem); + s3_file->use_multi_part_download = use_multi_part_download; + s3_file + ->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] = + buffer_size; + tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), + reader.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + auto file_size = + tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + + std::size_t part_count = (std::max)( + static_cast((file_size + buffer_size - 1) / buffer_size), + static_cast(1)); + std::unique_ptr buffer{new char[buffer_size]}; + std::stringstream ss; + + uint64_t offset = 0; + uint64_t server_size = 0; + for (size_t i = 0; i < part_count; i++) { + offset = i * buffer_size; + buffer_size = + (i == part_count - 1) ? 
file_size - server_size : buffer_size; + auto read = tf_random_access_file::Read(reader.get(), offset, buffer_size, + buffer.get(), status_); + if (TF_GetCode(status_) != TF_OK) return ""; + if (read > 0) { + ss.write(buffer.get(), read); + server_size += static_cast(read); + } + if (server_size == file_size) break; + if (read != buffer_size) { + if (read == 0) + TF_SetStatus(status_, TF_OUT_OF_RANGE, "eof"); + else + TF_SetStatus( + status_, TF_DATA_LOSS, + ("truncated record at " + std::to_string(offset)).c_str()); + return ""; + } + } + + if (file_size != server_size) { + TF_SetStatus(status_, TF_DATA_LOSS, + std::string("expected " + std::to_string(file_size) + + " got " + std::to_string(server_size) + " bytes") + .c_str()); + return ""; + } + TF_SetStatus(status_, TF_OK, ""); + return ss.str(); + } + + protected: + TF_Filesystem* filesystem_; + TF_Status* status_; + + private: + std::string root_dir_; +}; + +TEST_F(S3FilesystemTest, NewRandomAccessFile) { + const std::string path = GetURIForPath("RandomAccessFile"); + const std::string content = "abcdefghijklmn"; + + WriteString(path, content); + ASSERT_TF_OK(status_); + + auto reader = GetReader(); + tf_s3_filesystem::NewRandomAccessFile(filesystem_, path.c_str(), reader.get(), + status_); + EXPECT_TF_OK(status_); + + std::string result; + result.resize(content.size()); + auto read = tf_random_access_file::Read(reader.get(), 0, content.size(), + &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(content.size(), result.size()); + EXPECT_EQ(content, result); + + result.clear(); + result.resize(4); + read = tf_random_access_file::Read(reader.get(), 2, 4, &result[0], status_); + result.resize(read); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, result.size()); + EXPECT_EQ(content.substr(2, 4), result); +} + +TEST_F(S3FilesystemTest, NewWritableFile) { + auto writer = GetWriter(); + const std::string path = GetURIForPath("WritableFile"); + tf_s3_filesystem::NewWritableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Append(writer.get(), "content1,", strlen("content1,"), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Append(writer.get(), "content2", strlen("content2"), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Flush(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Sync(writer.get(), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); + + auto content = ReadAll(path); + EXPECT_TF_OK(status_); + EXPECT_EQ("content1,content2", content); +} + +TEST_F(S3FilesystemTest, NewAppendableFile) { + const std::string path = GetURIForPath("AppendableFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + + auto writer = GetWriter(); + tf_s3_filesystem::NewAppendableFile(filesystem_, path.c_str(), writer.get(), + status_); + EXPECT_TF_OK(status_); + tf_writable_file::Append(writer.get(), "content", strlen("content"), status_); + EXPECT_TF_OK(status_); + tf_writable_file::Close(writer.get(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(S3FilesystemTest, NewReadOnlyMemoryRegionFromFile) { + const std::string path = GetURIForPath("MemoryFile"); + const std::string content = "content"; + WriteString(path, content); + ASSERT_TF_OK(status_); + + std::unique_ptr + region(new TF_ReadOnlyMemoryRegion, [](TF_ReadOnlyMemoryRegion* file) { + if (file != nullptr) { + if (file->plugin_memory_region != nullptr) + tf_read_only_memory_region::Cleanup(file); + delete file; + } + }); 
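+  // Start from a null `plugin_memory_region` so the custom deleter above is a
+  // no-op if `NewReadOnlyMemoryRegionFromFile` fails before populating it.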
+ region->plugin_memory_region = nullptr; + tf_s3_filesystem::NewReadOnlyMemoryRegionFromFile(filesystem_, path.c_str(), + region.get(), status_); + EXPECT_TF_OK(status_); + std::string result(reinterpret_cast( + tf_read_only_memory_region::Data(region.get())), + tf_read_only_memory_region::Length(region.get())); + EXPECT_EQ(content, result); +} + +TEST_F(S3FilesystemTest, PathExists) { + const std::string path = GetURIForPath("PathExists"); + tf_s3_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status_)) << TF_Message(status_); + TF_SetStatus(status_, TF_OK, ""); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_s3_filesystem::PathExists(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(S3FilesystemTest, GetChildren) { + const std::string base = GetURIForPath("GetChildren"); + tf_s3_filesystem::CreateDir(filesystem_, base.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(base, "TestFile.csv"); + WriteString(file, "test"); + EXPECT_TF_OK(status_); + + const std::string subdir = io::JoinPath(base, "SubDir"); + tf_s3_filesystem::CreateDir(filesystem_, subdir.c_str(), status_); + EXPECT_TF_OK(status_); + const std::string subfile = io::JoinPath(subdir, "TestSubFile.csv"); + WriteString(subfile, "test"); + EXPECT_TF_OK(status_); + + char** entries; + auto num_entries = tf_s3_filesystem::GetChildren(filesystem_, base.c_str(), + &entries, status_); + EXPECT_TF_OK(status_); + + std::vector childrens; + for (int i = 0; i < num_entries; ++i) { + childrens.push_back(entries[i]); + } + std::sort(childrens.begin(), childrens.end()); + EXPECT_EQ(std::vector({"SubDir", "TestFile.csv"}), childrens); +} + +TEST_F(S3FilesystemTest, DeleteFile) { + const std::string path = GetURIForPath("DeleteFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + tf_s3_filesystem::DeleteFile(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); +} + +TEST_F(S3FilesystemTest, CreateDir) { + // s3 object storage doesn't support empty directory, we create file in the + // directory + const std::string dir = GetURIForPath("CreateDir"); + tf_s3_filesystem::CreateDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + + const std::string file = io::JoinPath(dir, "CreateDirFile.csv"); + WriteString(file, "test"); + ASSERT_TF_OK(status_); + + TF_FileStatistics stat; + tf_s3_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_TRUE(stat.is_directory); +} + +TEST_F(S3FilesystemTest, DeleteDir) { + // s3 object storage doesn't support empty directory, we create file in the + // directory + const std::string dir = GetURIForPath("DeleteDir"); + const std::string file = io::JoinPath(dir, "DeleteDirFile.csv"); + WriteString(file, "test"); + ASSERT_TF_OK(status_); + tf_s3_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_NE(TF_GetCode(status_), TF_OK); + + TF_SetStatus(status_, TF_OK, ""); + tf_s3_filesystem::DeleteFile(filesystem_, file.c_str(), status_); + EXPECT_TF_OK(status_); + tf_s3_filesystem::DeleteDir(filesystem_, dir.c_str(), status_); + EXPECT_TF_OK(status_); + TF_FileStatistics stat; + tf_s3_filesystem::Stat(filesystem_, dir.c_str(), &stat, status_); + EXPECT_EQ(TF_GetCode(status_), TF_NOT_FOUND) << TF_Message(status_); +} + +TEST_F(S3FilesystemTest, StatFile) { + const std::string path = GetURIForPath("StatFile"); + WriteString(path, "test"); + ASSERT_TF_OK(status_); + + TF_FileStatistics stat; + 
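+  // A plain object key (no trailing '/') should stat as a 4-byte regular file.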
tf_s3_filesystem::Stat(filesystem_, path.c_str(), &stat, status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(4, stat.length); + EXPECT_FALSE(stat.is_directory); +} + +TEST_F(S3FilesystemTest, SimpleCopyFile) { + const std::string src = GetURIForPath("SimpleCopySrc"); + const std::string dst = GetURIForPath("SimpleCopyDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_s3_filesystem::CopyFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ(result, "test"); +} + +TEST_F(S3FilesystemTest, RenameFile) { + const std::string src = GetURIForPath("RenameFileSrc"); + const std::string dst = GetURIForPath("RenameFileDst"); + WriteString(src, "test"); + ASSERT_TF_OK(status_); + + tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test", result); +} + +TEST_F(S3FilesystemTest, RenameFileOverwrite) { + const std::string src = GetURIForPath("RenameFileOverwriteSrc"); + const std::string dst = GetURIForPath("RenameFileOverwriteDst"); + + WriteString(src, "test_old"); + ASSERT_TF_OK(status_); + WriteString(dst, "test_new"); + ASSERT_TF_OK(status_); + + tf_s3_filesystem::PathExists(filesystem_, dst.c_str(), status_); + EXPECT_TF_OK(status_); + tf_s3_filesystem::RenameFile(filesystem_, src.c_str(), dst.c_str(), status_); + EXPECT_TF_OK(status_); + + auto result = ReadAll(dst); + EXPECT_TF_OK(status_); + EXPECT_EQ("test_old", result); +} + +// Test against large file. +TEST_F(S3FilesystemTest, ReadLargeFile) { + auto local_path = GetLocalLargeFile(); + auto server_path = GetServerLargeFile(); + if (local_path.empty() || server_path.empty()) GTEST_SKIP(); + std::ifstream in(local_path, std::ios::binary); + std::string local_content((std::istreambuf_iterator(in)), + std::istreambuf_iterator()); + + constexpr size_t buffer_size = 50 * 1024 * 1024; + auto server_content = ReadAllInChunks(server_path, buffer_size, true); + ASSERT_TF_OK(status_); + EXPECT_EQ(local_content, server_content); + + server_content = ReadAllInChunks(server_path, buffer_size, false); + ASSERT_TF_OK(status_); + EXPECT_EQ(local_content, server_content); +} + +TEST_F(S3FilesystemTest, CopyLargeFile) { + auto server_path = GetServerLargeFile(); + if (server_path.empty()) GTEST_SKIP(); + + auto path = GetURIForPath("CopyLargeFile"); + constexpr size_t buffer_size = 5 * 1024 * 1024; + auto s3_file = + static_cast(filesystem_->plugin_filesystem); + s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] = + buffer_size; + tf_s3_filesystem::CopyFile(filesystem_, server_path.c_str(), path.c_str(), + status_); + EXPECT_TF_OK(status_); + + auto server_size = + tf_s3_filesystem::GetFileSize(filesystem_, server_path.c_str(), status_); + EXPECT_TF_OK(status_); + auto actual_size = + tf_s3_filesystem::GetFileSize(filesystem_, path.c_str(), status_); + EXPECT_TF_OK(status_); + EXPECT_EQ(server_size, actual_size); +} + +} // namespace +} // namespace tensorflow + +GTEST_API_ int main(int argc, char** argv) { + tensorflow::testing::InstallStacktraceHandler(); + if (!GetTmpDir()) { + std::cerr << "Could not read S3_TEST_TMPDIR env"; + return -1; + } + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD new file mode 100644 index 00000000000..80c4e8d9791 --- /dev/null +++ 
b/tensorflow/c/experimental/gradients/BUILD @@ -0,0 +1,24 @@ +# Library of gradient functions. +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "math_grad", + srcs = ["math_grad.cc"], + hdrs = [ + "math_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/core/lib/llvm_rtti", + ], +) diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc new file mode 100644 index 00000000000..d8b70848d4e --- /dev/null +++ b/tensorflow/c/experimental/gradients/math_grad.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/gradients/math_grad.h" + +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" + +using std::vector; +using tensorflow::ops::Conj; +using tensorflow::ops::Identity; +using tensorflow::ops::Mul; + +namespace tensorflow { +namespace gradients { +namespace { + +class AddGradientFunction : public GradientFunction { + public: + Status Compute(Context* ctx, + absl::Span grad_inputs, + vector* grad_outputs) override { + grad_outputs->resize(2); + vector identity_outputs(1); + // TODO(b/145674566): Handle name unification in tracing code. + // TODO(b/161805092): Support broadcasting. 
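+    // Since d(x + y)/dx = d(x + y)/dy = 1, the gradient w.r.t. each input is
+    // just the incoming gradient, forwarded through an Identity op.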
+ TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(identity_outputs), + "Identity0")); + (*grad_outputs)[0] = identity_outputs[0]; + TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, + absl::MakeSpan(identity_outputs), + "Identity1")); + (*grad_outputs)[1] = identity_outputs[0]; + return Status::OK(); + } + ~AddGradientFunction() override {} +}; + +class ExpGradientFunction : public GradientFunction { + public: + explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) { + exp->Ref(); + } + Status Compute(Context* ctx, + absl::Span grad_inputs, + vector* grad_outputs) override { + vector conj_outputs(1); + TF_RETURN_IF_ERROR( + Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj")); + AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]); + grad_outputs->resize(1); + TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]}, + absl::MakeSpan(*grad_outputs), "ExpGradMul")); + return Status::OK(); + } + ~ExpGradientFunction() override {} + + private: + AbstractTensorHandlePtr exp_; +}; + +} // namespace + +GradientFunction* AddRegisterer(const ForwardOperation& op) { + return new AddGradientFunction; +} + +GradientFunction* ExpRegisterer(const ForwardOperation& op) { + return new ExpGradientFunction(op.outputs[0]); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h new file mode 100644 index 00000000000..6c7242a1a49 --- /dev/null +++ b/tensorflow/c/experimental/gradients/math_grad.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +GradientFunction* AddRegisterer(const ForwardOperation& op); +GradientFunction* ExpRegisterer(const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_ diff --git a/tensorflow/c/experimental/network.cc b/tensorflow/c/experimental/network.cc deleted file mode 100644 index 97e63ec6259..00000000000 --- a/tensorflow/c/experimental/network.cc +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/experimental/network.h" - -#include -#include - -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/experimental/network_internal.h" -#include "tensorflow/c/experimental/rendezvous_internal.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" - -using tensorflow::ServerFactory; - -namespace tensorflow { - -/* static */ Status CGrpcServer::Create( - const ServerDef& server_def, - void* (*init_function)(const TF_GrpcServer*, TF_Status*), - void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*), - TF_RemoteRendezvousBuilder* rendezvous_builder, - std::unique_ptr* out_server) { - auto* grpc_server = new CGrpcServer(server_def, start_function, stop_function, - join_function, delete_function); - - GrpcServerOptions options; - options.rendezvous_mgr_func = [rendezvous_builder](const WorkerEnv* env) { - return new CRendezvousMgr(env, rendezvous_builder); - }; - TF_RETURN_IF_ERROR(grpc_server->Init(options)); - TF_Status* tf_status = TF_NewStatus(); - grpc_server->SetContext(init_function( - reinterpret_cast(grpc_server), tf_status)); - TF_RETURN_IF_ERROR(tf_status->status); - TF_DeleteStatus(tf_status); - - out_server->reset(grpc_server); - return Status::OK(); -} - -Status CGrpcServer::Start() { - Status status = GrpcServer::Start(); - TF_Status* tf_status = TF_NewStatus(); - (*start_function_)(reinterpret_cast(this), context_, - tf_status); - status.Update(tf_status->status); - TF_DeleteStatus(tf_status); - return status; -} - -Status CGrpcServer::Stop() { - Status status = GrpcServer::Stop(); - TF_Status* tf_status = TF_NewStatus(); - (*stop_function_)(reinterpret_cast(this), context_, - tf_status); - status.Update(tf_status->status); - TF_DeleteStatus(tf_status); - return status; -} - -Status CGrpcServer::Join() { - Status status = GrpcServer::Join(); - TF_Status* tf_status = TF_NewStatus(); - (*join_function_)(reinterpret_cast(this), context_, - tf_status); - status.Update(tf_status->status); - TF_DeleteStatus(tf_status); - return status; -} - -namespace { -// Factory that creates CGrpcServer instances. 
-class CServerFactory : public ServerFactory { - public: - CServerFactory(bool (*accept_function)(const char*), - void* (*init_function)(const TF_GrpcServer*, TF_Status*), - void (*start_function)(const TF_GrpcServer*, void*, - TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*), - TF_RemoteRendezvousBuilder* rendezvous_builder) - : accept_function_(accept_function), - init_function_(init_function), - start_function_(start_function), - stop_function_(stop_function), - join_function_(join_function), - delete_function_(delete_function), - rendezvous_builder_(rendezvous_builder) {} - - Status NewServer(const ServerDef& server_def, const Options& options, - std::unique_ptr* out_server) override { - TF_RETURN_IF_ERROR(CGrpcServer::Create( - server_def, init_function_, start_function_, stop_function_, - join_function_, delete_function_, rendezvous_builder_, out_server)); - return Status::OK(); - } - - // Returns true if and only if this factory can create a server - // based on the given `server_def`. - bool AcceptsOptions(const ServerDef& server_def) override { - return (*accept_function_)(server_def.protocol().c_str()); - } - - private: - bool (*accept_function_)(const char* protocol); - void* (*init_function_)(const TF_GrpcServer*, TF_Status*); - void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*delete_function_)(void*); - TF_RemoteRendezvousBuilder* rendezvous_builder_; -}; -} // namespace -} // namespace tensorflow - -// Server factory representation to use in C API. -// Holds CServerFactory pointer. -struct TF_GrpcServerFactory { - ::tensorflow::CServerFactory* factory; -}; - -TF_GrpcServerFactory* TF_NewGrpcServerFactory( - bool (*accept_function)(const char*), - void* (*init_function)(const TF_GrpcServer*, TF_Status*), - void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*), - TF_RemoteRendezvousBuilder* rendezvous_builder) { - TF_GrpcServerFactory* server_factory = new TF_GrpcServerFactory; - server_factory->factory = new ::tensorflow::CServerFactory( - accept_function, init_function, start_function, stop_function, - join_function, delete_function, rendezvous_builder); - return server_factory; -} - -void TF_DeleteGrpcServerFactory(TF_GrpcServerFactory* server_factory) { - DCHECK_NE(server_factory, nullptr); - delete server_factory; -} - -void TF_RegisterGrpcServerFactory(const char* server_type, - TF_GrpcServerFactory* server_factory) { - ServerFactory::Register(server_type, server_factory->factory); -} diff --git a/tensorflow/c/experimental/network.h b/tensorflow/c/experimental/network.h deleted file mode 100644 index bd74ec8ffec..00000000000 --- a/tensorflow/c/experimental/network.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ -#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ - -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/experimental/rendezvous.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -------------------------------------------------------------------------- -// C API for TensorFlow Networking. -// NOTE: This API is unstable and almost certainly will change in the near -// future. -// -// Users wishing to register a custom GrpcServer should call -// TF_NewServerFactory and then TF_RegisterGrpcServerFactory. -// -// Example: -// ```c++ -// auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( -// rendezvous_init_function, -// receive_from_remote_async_function, -// rendezvous_delete_function); -// -// TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( -// accept_function, -// init_function, -// start_function, -// stop_function, -// join_function, -// delete_function, -// rendezvous_builder); -// TF_RegisterGrpcServerFactory("customfactory", factory); -// ... -// TF_DeleteGrpcServerFactory(factory); -// ``` - -typedef struct TF_GrpcServerFactory TF_GrpcServerFactory; -typedef struct TF_GrpcServerOptions TF_GrpcServerOptions; -typedef struct TF_GrpcServer TF_GrpcServer; -typedef struct TF_ServerContext { - TF_GrpcServer* const server; - void* context; -} TF_ServerContext; - -// Creates a new TF_GrpcServerFactory instance. Caller takes ownership -// of TF_GrpcServerFactory instance and should deallocate it by calling -// TF_GrpcDeleteServerFactory. -// accept_function should return true if this ServerFactory can create -// server instances for the given protocol name (for e.g. grpc+verbs). -// GRPC servers created by this factory will call provided -// init_function, start_function, stop_function, join_function and -// delete_function. -// -// Note that clean shutdown is currently not implemented for GrpcServer. -// So, stop_function will never be called now but may be in the future -// when stop mechanism is supported. -TF_CAPI_EXPORT extern TF_GrpcServerFactory* TF_NewGrpcServerFactory( - bool (*accept_function)(const char*), - void* (*init_function)(const TF_GrpcServer*, TF_Status*), - void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*), - TF_RemoteRendezvousBuilder* rendezvous_builder); - -// Deletes TF_GrpcServerFactory instances. -// Note that this function only deletes TF_GrpcServerFactory wrapper. -// Actual underlying server factory would not be deleted and will -// remain registered. -TF_CAPI_EXPORT extern void TF_DeleteGrpcServerFactory( - TF_GrpcServerFactory* server_factory); - -// Registers provided server_factory for the given server_type. -// server_type must be unique to the server factory. 
-TF_CAPI_EXPORT extern void TF_RegisterGrpcServerFactory( - const char* server_type, TF_GrpcServerFactory* server_factory); - -#ifdef __cplusplus -} /* end extern "C" */ -#endif -#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_H_ diff --git a/tensorflow/c/experimental/network_internal.h b/tensorflow/c/experimental/network_internal.h deleted file mode 100644 index 389de440b70..00000000000 --- a/tensorflow/c/experimental/network_internal.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ -#define TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ - -#include - -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/experimental/network.h" -#include "tensorflow/c/experimental/rendezvous.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/tensorflow_server.pb.h" - -namespace tensorflow { - -// GrpcServer implementation that forwards calls to callbacks. -class CGrpcServer : public GrpcServer { - protected: - CGrpcServer(const ServerDef& server_def, - void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*)) - : GrpcServer(server_def, ::tensorflow::Env::Default()), - start_function_(start_function), - stop_function_(stop_function), - join_function_(join_function), - delete_function_(delete_function), - context_(nullptr) {} - - public: - static Status Create( - const ServerDef& server_def, - void* (*init_function)(const TF_GrpcServer*, TF_Status*), - void (*start_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*stop_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*join_function)(const TF_GrpcServer*, void*, TF_Status*), - void (*delete_function)(void*), - TF_RemoteRendezvousBuilder* rendezvous_builder, - std::unique_ptr* out_server); - - Status Start() override; - Status Stop() override; - Status Join() override; - - ~CGrpcServer() override { delete_function_(context_); } - - protected: - void SetContext(void* context) { context_ = context; } - - private: - void (*start_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*stop_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*join_function_)(const TF_GrpcServer*, void*, TF_Status*); - void (*delete_function_)(void*); - void* context_; - - friend class NetworksTest; -}; - -} // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_NETWORK_INTERNAL_H_ diff --git a/tensorflow/c/experimental/network_test.cc b/tensorflow/c/experimental/network_test.cc deleted file mode 100644 index b7a50008c37..00000000000 --- 
a/tensorflow/c/experimental/network_test.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/experimental/network.h" - -#include -#include -#include - -#include -#include - -#include "absl/synchronization/notification.h" -#include "absl/time/time.h" -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/experimental/network_internal.h" -#include "tensorflow/c/experimental/rendezvous.h" -#include "tensorflow/c/experimental/rendezvous_internal.h" -#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" -#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h" -#include "tensorflow/core/distributed_runtime/server_lib.h" -#include "tensorflow/core/distributed_runtime/session_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/distributed_runtime/worker_session.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/strcat.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/cluster.pb.h" -#include "tensorflow/core/protobuf/tensorflow_server.pb.h" - -namespace tensorflow { - -bool accept_functionA(const char* protocol_name) { - return strcmp(protocol_name, "grpc+A") == 0; -} - -bool accept_functionB(const char* protocol_name) { - return strcmp(protocol_name, "grpc+B") == 0; -} - -struct SomeServerData { - bool server_started = false; -}; - -struct SomeRendezvousData { - int test = 0; -}; - -void* init_function(const TF_GrpcServer* server, TF_Status* status) { - SomeServerData* server_data = new SomeServerData(); - TF_SetStatus(status, TF_OK, ""); - return server_data; -} - -void start_function(const TF_GrpcServer* server, void* context, - TF_Status* status) { - auto* server_data = static_cast(context); - server_data->server_started = true; - TF_SetStatus(status, TF_OK, ""); -} - -void stop_function(const TF_GrpcServer* server, void* context, - TF_Status* status) { - TF_SetStatus(status, TF_OK, ""); -} - -void join_function(const TF_GrpcServer* server, void* context, - TF_Status* status) { - TF_SetStatus(status, TF_OK, ""); -} - -void delete_function(void* context) { - auto* server_data = static_cast(context); - delete server_data; -} - -void* rendezvous_init_function(void* server_context) { - return new SomeRendezvousData(); -} - -void Deallocator(void* data, size_t, void* arg) { - tensorflow::cpu_allocator()->DeallocateRaw(data); - *reinterpret_cast(arg) = true; -} - -void receive_from_remote_async_function(TF_ParsedKey* key, - TF_RendezvousArgs* args, - TF_RendezvousDoneCallback* callback, - void* context) { - 
// Create dummy tensor - const int num_bytes = 6 * sizeof(float); - float* values = - reinterpret_cast(tensorflow::cpu_allocator()->AllocateRaw( - EIGEN_MAX_ALIGN_BYTES, num_bytes)); - int64_t dims[] = {2, 3}; - bool deallocator_called = false; - auto* tensor = TF_NewTensor(TF_FLOAT, dims, 2, values, num_bytes, - &Deallocator, &deallocator_called); - callback->tensor = tensor; - auto* tf_status = TF_NewStatus(); - TF_SetStatus(tf_status, TF_OK, ""); - callback->status = tf_status; - TF_RendezvousDone(callback); - TF_DeleteStatus(tf_status); - TF_DeleteTensor(tensor); -} - -void rendezvous_delete_function(void* context) { - auto* rendezvous_data = static_cast(context); - delete rendezvous_data; -} - -tensorflow::ServerDef GetServerDef(const string& protocol, - const string& job_name, int num_tasks) { - tensorflow::ServerDef server_def; - server_def.set_protocol(protocol); - server_def.set_job_name(job_name); - server_def.set_task_index(0); - tensorflow::ClusterDef* cluster_def = server_def.mutable_cluster(); - tensorflow::JobDef* job_def = cluster_def->add_job(); - job_def->set_name(job_name); - for (int i = 0; i < num_tasks; i++) { - int port = tensorflow::testing::PickUnusedPortOrDie(); - job_def->mutable_tasks()->insert( - {i, tensorflow::strings::StrCat("localhost:", port)}); - } - return server_def; -} - -class NetworksTest : public ::testing::Test { - public: - ~NetworksTest() override {} - - SomeServerData* GetServerData(CGrpcServer* server) { - EXPECT_NE(server->context_, nullptr); - return static_cast(server->context_); - } -}; - -Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, - const string& receiver, const string& name) { - Rendezvous::ParsedKey result; - CHECK( - Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, - name, FrameAndIter(0, 0)), - &result) - .ok()); - return result; -} - -void InitializeRendezvous(GrpcServer* grpc_server, ServerDef* server_def, - RemoteRendezvous* remote_rendezvous) { - int rendezvous_id = 0; - auto session_name = tensorflow::strings::StrCat("test_", rendezvous_id); - TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, *server_def, true)); - - std::shared_ptr worker_session; - TF_EXPECT_OK(grpc_server->worker_env()->session_mgr->WorkerSessionForSession( - session_name, &worker_session)); - - TF_EXPECT_OK(remote_rendezvous->Initialize(worker_session.get())); -} - -TEST_F(NetworksTest, TestStartServer) { - auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( - rendezvous_init_function, receive_from_remote_async_function, - rendezvous_delete_function); - - TF_Status* tf_status = TF_NewStatus(); - TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( - accept_functionA, init_function, start_function, stop_function, - join_function, delete_function, rendezvous_builder); - TF_RegisterGrpcServerFactory("testfactoryA", factory); - - ServerDef server_def = GetServerDef("grpc+A", "localhost", 1); - std::unique_ptr server; - TF_EXPECT_OK(NewServer(server_def, &server)); - auto* grpc_server = static_cast(server.get()); - auto* server_data = GetServerData(grpc_server); - ASSERT_FALSE(server_data->server_started); - - TF_EXPECT_OK(server->Start()); - ASSERT_TRUE(server_data->server_started); - - TF_DeleteStatus(tf_status); - TF_DeleteGrpcServerFactory(factory); - TF_DeleteRemoteRendezvousBuilder(rendezvous_builder); - // TODO(annarev): find a clean way to shutdown server. 
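The Key() helper earlier in this file wraps Rendezvous::CreateKey/ParseKey; a minimal round trip looks like the sketch below (FrameAndIter(0, 0) is the default frame/iteration, and the device strings mirror the ones used in the tests).

Rendezvous::ParsedKey parsed;
// Build the full key string, then parse it back into its components.
const std::string key_string = Rendezvous::CreateKey(
    /*src_device=*/"/job:localhost/replica:1/task:2/device:CPU:0",
    /*src_incarnation=*/1,
    /*dst_device=*/"/job:localhost/replica:0/task:0/device:CPU:0",
    /*name=*/"test", FrameAndIter(0, 0));
TF_CHECK_OK(Rendezvous::ParseKey(key_string, &parsed));
// parsed.src_device, parsed.dst_device and parsed.FullKey() are now populated.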
- server.release(); -} - -TEST_F(NetworksTest, TestReceiveData) { - auto* rendezvous_builder = TF_NewRemoteRendezvousBuilder( - rendezvous_init_function, receive_from_remote_async_function, - rendezvous_delete_function); - - TF_Status* tf_status = TF_NewStatus(); - TF_GrpcServerFactory* factory = TF_NewGrpcServerFactory( - accept_functionB, init_function, start_function, stop_function, - join_function, delete_function, rendezvous_builder); - TF_RegisterGrpcServerFactory("testfactoryB", factory); - - ServerDef server_def = GetServerDef("grpc+B", "localhost", 1); - std::unique_ptr server; - TF_EXPECT_OK(NewServer(server_def, &server)); - auto* grpc_server = static_cast(server.get()); - - TF_EXPECT_OK(server->Start()); - auto* rendezvous_mgr = grpc_server->worker_env()->rendezvous_mgr; - auto* remote_rendezvous = rendezvous_mgr->Find(0); - - auto key = Key("/job:localhost/replica:1/task:2/device:CPU:0", 1, - "/job:localhost/replica:0/task:0/device:CPU:0", "test"); - Rendezvous::Args args; - bool done_callback_called = false; - auto* done_callback_called_ptr = &done_callback_called; - absl::Notification notification; - auto* notification_ptr = ¬ification; - - InitializeRendezvous(grpc_server, &server_def, remote_rendezvous); - remote_rendezvous->RecvAsync( - key, args, - [done_callback_called_ptr, notification_ptr]( - const Status&, const Rendezvous::Args&, const Rendezvous::Args&, - const Tensor&, const bool) mutable { - *done_callback_called_ptr = true; - notification_ptr->Notify(); - }); - notification.WaitForNotificationWithTimeout(absl::Seconds(10)); - ASSERT_EQ(done_callback_called, true); - - TF_DeleteStatus(tf_status); - TF_DeleteGrpcServerFactory(factory); - TF_DeleteRemoteRendezvousBuilder(rendezvous_builder); - // Server doesn't have a clean shutdown. - server.release(); -} - -} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD new file mode 100644 index 00000000000..d13d7a72d3e --- /dev/null +++ b/tensorflow/c/experimental/ops/BUILD @@ -0,0 +1,48 @@ +# Experimental ops. These will eventually be replaced by machine-generated versions. +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "array_ops", + srcs = [ + "array_ops.cc", + ], + hdrs = [ + "array_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) + +cc_library( + name = "math_ops", + srcs = [ + "math_ops.cc", + ], + hdrs = [ + "math_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":array_ops", + "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc new file mode 100644 index 00000000000..ab2d114d9d9 --- /dev/null +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
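The TestReceiveData case earlier in this hunk drives the asynchronous receive path. The pattern is to record completion state in the done callback, signal an absl::Notification, and bound the wait with a timeout so a stuck RPC fails the test instead of hanging it. A trimmed sketch, assuming the `remote_rendezvous`, `key`, and `args` set up in that test:

absl::Notification done;
bool callback_ran = false;
remote_rendezvous->RecvAsync(
    key, args,
    [&callback_ran, &done](const Status& status, const Rendezvous::Args&,
                           const Rendezvous::Args&, const Tensor&,
                           bool is_dead) {
      // Record that the receive completed and unblock the waiting thread.
      callback_ran = true;
      done.Notify();
    });
done.WaitForNotificationWithTimeout(absl::Seconds(10));
ASSERT_TRUE(callback_ran);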
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/ops/array_ops.h" + +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace ops { +// Creates an Identity op. +Status Identity(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr identity_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + identity_op->Reset("Identity", /*raw_device_name=*/nullptr)); + if (isa(identity_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(identity_op.get()) + ->SetOpName(name)); + } + TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0])); + int num_retvals = 1; + return identity_op->Execute(outputs, &num_retvals); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h new file mode 100644 index 00000000000..226461fd286 --- /dev/null +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +namespace tensorflow { +namespace ops { +Status Identity(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc new file mode 100644 index 00000000000..e91acbd6370 --- /dev/null +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
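The absl::Span template arguments in the Identity declaration above were lost in this rendering; presumably inputs are absl::Span<AbstractTensorHandle* const> and outputs are absl::Span<AbstractTensorHandle*>. A hypothetical call site under that assumption (CallIdentity, ctx, and x are placeholder names):

Status CallIdentity(AbstractContext* ctx, AbstractTensorHandle* x,
                    AbstractTensorHandle** result) {
  // One input, one expected return value.
  AbstractTensorHandle* inputs[] = {x};
  std::vector<AbstractTensorHandle*> outputs(1);
  TF_RETURN_IF_ERROR(ops::Identity(ctx, absl::MakeConstSpan(inputs),
                                   absl::MakeSpan(outputs), "my_identity"));
  *result = outputs[0];
  return Status::OK();
}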
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/ops/math_ops.h" + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +namespace tensorflow { +namespace ops { +using tensorflow::tracing::TracingOperation; + +Status Mul(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr mul_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr)); + if (isa(mul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(mul_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1])); + int num_retvals = 1; + return mul_op->Execute(outputs, &num_retvals); +} + +Status Conj(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + auto dtype = inputs[0]->DataType(); + if (DataTypeIsFloating(BaseType(dtype)) || + DataTypeIsInteger(BaseType(dtype))) { + TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name)); + } else { + return errors::Unimplemented("Conj does not support complex types yet."); + } + return Status::OK(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h new file mode 100644 index 00000000000..4d7c3d838ce --- /dev/null +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ + +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" + +namespace tensorflow { +namespace ops { +Status Mul(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); +Status Conj(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_ diff --git a/tensorflow/c/experimental/rendezvous.cc b/tensorflow/c/experimental/rendezvous.cc deleted file mode 100644 index c996cfb44f3..00000000000 --- a/tensorflow/c/experimental/rendezvous.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
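In the Mul and Identity helpers above, the template arguments of isa<>/dyn_cast<> were stripped by the rendering. Given the `using tensorflow::tracing::TracingOperation;` declaration, the guarded name-setting presumably reads:

// Only tracing (graph-building) operations carry a user-visible op name;
// eager ops skip SetOpName entirely.
if (isa<TracingOperation>(mul_op.get())) {
  TF_RETURN_IF_ERROR(
      dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}

Conj forwards to Identity because conjugation is a no-op for real-valued and integer dtypes; complex dtypes would need a dedicated kernel, hence the Unimplemented error above.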
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/experimental/rendezvous.h" - -#include - -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/c_api_internal.h" -#include "tensorflow/c/experimental/rendezvous_internal.h" -#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/stringpiece.h" - -namespace tensorflow { - -CRemoteRendezvous::CRemoteRendezvous(const WorkerEnv* env, int64 step_id, - void (*receive_from_remote_async_function)( - TF_ParsedKey*, TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, - void* context), - void (*delete_function)(void* context), - void* server_context) - : BaseRemoteRendezvous(env, step_id), - receive_from_remote_async_function_(receive_from_remote_async_function), - delete_function_(delete_function), - context_(nullptr) {} - -void CRemoteRendezvous::RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& args, - DoneCallback done) { - if (args.cancellation_manager != nullptr) { - VLOG(1) << "WARNING: CRemoteRendezvous does not support cancellation."; - } - TF_ParsedKey key; - key.src_device = parsed.src_device.data(); - key.src_device_len = parsed.src_device.size(); - key.dst_device = parsed.dst_device.data(); - key.dst_device_len = parsed.dst_device.size(); - key.full_key = parsed.FullKey().data(); - key.full_key_len = parsed.FullKey().size(); - - TF_DeviceContext* device_context = new TF_DeviceContext(); - device_context->context = args.device_context; - - TF_AllocatorAttributes* alloc_attrs = new TF_AllocatorAttributes(); - alloc_attrs->value = args.alloc_attrs.value; - alloc_attrs->scope_id = args.alloc_attrs.scope_id; - alloc_attrs->on_host = args.alloc_attrs.on_host(); - alloc_attrs->nic_compatible = args.alloc_attrs.nic_compatible(); - - TF_RendezvousArgs* cargs = new TF_RendezvousArgs(); - cargs->device_context = device_context; - cargs->alloc_attrs = alloc_attrs; - - TF_RendezvousDoneCallback* done_callback = new TF_RendezvousDoneCallback(); - done_callback->done_callback = done; - done_callback->recv_args = cargs; - - receive_from_remote_async_function_(&key, cargs, done_callback, context_); -} - -CRemoteRendezvous::~CRemoteRendezvous() { delete_function_(context_); } -} // namespace tensorflow - -TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder( - void* (*init_function)(void* server_context), - void (*receive_from_remote_async_function)(TF_ParsedKey*, - TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, - void* context), - void (*delete_function)(void* context)) { - TF_RemoteRendezvousBuilder* builder = new TF_RemoteRendezvousBuilder(); - builder->init_function = init_function; - builder->delete_function = delete_function; - builder->receive_from_remote_async_function = - receive_from_remote_async_function; - 
return builder; -} - -void TF_DeleteRemoteRendezvousBuilder( - TF_RemoteRendezvousBuilder* rendezvous_builder) { - DCHECK_NE(rendezvous_builder, nullptr); - delete rendezvous_builder; -} - -TF_CAPI_EXPORT extern void TF_RendezvousDone( - TF_RendezvousDoneCallback* callback) { - DCHECK_NE(callback, nullptr); - ::tensorflow::Tensor tensor; - TF_CHECK_OK(TF_TensorToTensor(callback->tensor, &tensor)); - ::tensorflow::Rendezvous::Args recv_args; - recv_args.alloc_attrs.value = callback->recv_args->alloc_attrs->value; - recv_args.alloc_attrs.scope_id = callback->recv_args->alloc_attrs->scope_id; - recv_args.device_context = callback->recv_args->device_context->context; - ::tensorflow::Rendezvous::Args sent_args; - - callback->done_callback(callback->status->status, sent_args, recv_args, - tensor, callback->dead); - - if (callback->recv_args) { - DCHECK_NE(callback->recv_args, nullptr); - DCHECK_NE(callback->recv_args->alloc_attrs, nullptr); - DCHECK_NE(callback->recv_args->device_context, nullptr); - delete callback->recv_args->alloc_attrs; - delete callback->recv_args->device_context; - delete callback->recv_args; - } - delete callback; - callback = nullptr; -} diff --git a/tensorflow/c/experimental/rendezvous.h b/tensorflow/c/experimental/rendezvous.h deleted file mode 100644 index 5b007d52429..00000000000 --- a/tensorflow/c/experimental/rendezvous.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_ -#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_ - -#include "tensorflow/c/c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// -------------------------------------------------------------------------- -// C API for Rendezvous. -// NOTE: This API is unstable and almost certainly will change in the near -// future. -// -// Custom rendezvous allows for custom implementations of Recv call. -// -// Users wishing to create custom rendezvous objects should call -// TF_NewRemoteRendezvousBuilder and pass returned TF_RemoteRendezvousBuilder -// to to TF_NewServerFactory. - -typedef struct TF_RemoteRendezvousBuilder TF_RemoteRendezvousBuilder; -typedef struct TF_ParsedKey TF_ParsedKey; -typedef struct TF_RendezvousArgs TF_RendezvousArgs; -typedef struct TF_RendezvousDoneCallback TF_RendezvousDoneCallback; - -// Creates a new TF_RemoteRendezvousBuilder instance. -// Rendezvous instances will forward calls to init_function, -// receive_from_remote_async_function and delete_function passed here. -// -// Note that receive_from_remote_async_function implementation must call -// TF_Done with the TF_DoneCallback passed as an argument. 
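A minimal receive callback honoring that contract is sketched below (MyRecv is a hypothetical name). Per the TF_RendezvousDone documentation further down, the callee fills in `tensor` and `status`, calls TF_RendezvousDone, and then remains responsible for deleting both.

void MyRecv(TF_ParsedKey* key, TF_RendezvousArgs* args,
            TF_RendezvousDoneCallback* callback, void* context) {
  // Produce some tensor for this key; here a single zeroed float.
  int64_t dims[] = {1};
  TF_Tensor* tensor = TF_AllocateTensor(TF_FLOAT, dims, /*num_dims=*/1,
                                        sizeof(float));
  static_cast<float*>(TF_TensorData(tensor))[0] = 0.0f;

  TF_Status* status = TF_NewStatus();
  TF_SetStatus(status, TF_OK, "");

  callback->tensor = tensor;
  callback->status = status;
  TF_RendezvousDone(callback);

  // TF_RendezvousDone does not take ownership of tensor/status.
  TF_DeleteStatus(status);
  TF_DeleteTensor(tensor);
}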
-TF_CAPI_EXPORT extern TF_RemoteRendezvousBuilder* TF_NewRemoteRendezvousBuilder( - void* (*init_function)(void* server_context), - void (*receive_from_remote_async_function)(TF_ParsedKey*, - TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, - void* context), - void (*delete_function)(void* context)); - -// Deletes TF_RemoteRendezvousBuilder instances. -TF_CAPI_EXPORT extern void TF_DeleteRemoteRendezvousBuilder( - TF_RemoteRendezvousBuilder* rendezvous_builder); - -// Calls TF_DoneCallback and destroys callback instance and -// TF_DoneCallback members except `tensor` and `status`. Caller is -// responsible for deleting `tensor` and `status` after TF_Done returns. -TF_CAPI_EXPORT extern void TF_RendezvousDone( - TF_RendezvousDoneCallback* callback); - -#ifdef __cplusplus -} /* end extern "C" */ -#endif -#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_H_ diff --git a/tensorflow/c/experimental/rendezvous_internal.h b/tensorflow/c/experimental/rendezvous_internal.h deleted file mode 100644 index f06686023e6..00000000000 --- a/tensorflow/c/experimental/rendezvous_internal.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ -#define TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ - -#include - -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/experimental/rendezvous.h" -#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" -#include "tensorflow/core/distributed_runtime/worker_env.h" -#include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/rendezvous.h" -#include "tensorflow/core/platform/macros.h" - -struct TF_ParsedKey { - // char* members might not be null-terminated. - const char* src_device; - size_t src_device_len; - const char* dst_device; - size_t dst_device_len; - const char* full_key; - size_t full_key_len; -}; - -struct TF_AllocatorAttributes { - bool on_host; - bool nic_compatible; - // NOTE: The upper 8 bits of the value are reserved for - // device-specific uses. Implementors of a device can interpret these - // upper 8 bits in device-specific ways, and ops implemented for those - // devices are responsible for setting those 8 bits appropriately. - tensorflow::uint32 value = 0; - // EXPERIMENTAL: If this is greater than zero, then allocation is delegated to - // a named special-purpose allocator on the same device. - tensorflow::int32 scope_id = 0; -}; - -struct TF_DeviceContext { - ::tensorflow::DeviceContext* context; -}; - -struct TF_RendezvousArgs { - const TF_DeviceContext* device_context; - const TF_AllocatorAttributes* alloc_attrs; -}; - -struct TF_RendezvousDoneCallback { - ::tensorflow::Rendezvous::DoneCallback done_callback; - - // TODO(annarev): figure out if we should also support sent_args. 
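As the comment on TF_ParsedKey above notes, its char* members are not null-terminated, so a consumer must pair each pointer with its length rather than treating it as a C string. A small sketch:

absl::string_view src(key->src_device, key->src_device_len);
absl::string_view dst(key->dst_device, key->dst_device_len);
std::string full_key(key->full_key, key->full_key_len);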
- const TF_RendezvousArgs* recv_args; - TF_Tensor* tensor = nullptr; - TF_Status* status; - bool dead; -}; - -struct TF_RemoteRendezvousBuilder { - void* (*init_function)(void* server_context); - void (*receive_from_remote_async_function)(TF_ParsedKey*, TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, - void* context); - void (*delete_function)(void* context); - void* server_context; -}; - -namespace tensorflow { - -class CRemoteRendezvous : public BaseRemoteRendezvous { - public: - CRemoteRendezvous(const WorkerEnv* env, int64 step_id, - void (*receive_from_remote_async_function)( - TF_ParsedKey*, TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, void* context), - void (*delete_function)(void* context), - void* server_context); - - void SetContext(void* context) { context_ = context; } - - protected: - void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& args, - DoneCallback done) override; - - private: - ~CRemoteRendezvous() override; - - void (*receive_from_remote_async_function_)(TF_ParsedKey*, TF_RendezvousArgs*, - TF_RendezvousDoneCallback*, - void* context); - void (*delete_function_)(void* context); - void* context_; - TF_DISALLOW_COPY_AND_ASSIGN(CRemoteRendezvous); -}; - -class CRendezvousMgr : public BaseRendezvousMgr { - public: - CRendezvousMgr(const WorkerEnv* env, - const TF_RemoteRendezvousBuilder* rendezvous_builder) - : BaseRendezvousMgr(env), rendezvous_builder_(rendezvous_builder) {} - - protected: - BaseRemoteRendezvous* Create(int64 step_id, - const WorkerEnv* worker_env) override { - auto* rendezvous = new CRemoteRendezvous( - worker_env, step_id, - rendezvous_builder_->receive_from_remote_async_function, - rendezvous_builder_->delete_function, - rendezvous_builder_->server_context); - - rendezvous->SetContext(rendezvous_builder_->init_function( - rendezvous_builder_->server_context)); - return rendezvous; - } - - private: - const TF_RemoteRendezvousBuilder* rendezvous_builder_; - TF_DISALLOW_COPY_AND_ASSIGN(CRendezvousMgr); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_C_EXPERIMENTAL_RENDEZVOUS_INTERNAL_H_ diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 5931e229e28..b2e432782de 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -26,6 +26,7 @@ cc_library( ":function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "@com_google_absl//absl/types:span", ], ) @@ -113,8 +114,23 @@ cc_library( deps = [ ":concrete_function", ":saved_model_api", + ":saved_model_utils", + "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core/ops:restore_ops", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:constants", + "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:tensor_handle", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", 
"@com_google_absl//absl/types:optional", ], ) @@ -131,6 +147,7 @@ cc_library( "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -199,6 +216,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "signature_flattening_test", + srcs = [ + "signature_flattening_test.cc", + ], + deps = [ + ":saved_model_utils", + "//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime/eager:core", + ], +) + tf_cc_test( name = "tf_concrete_function_loading_test", srcs = [ diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 2cc627bcf27..da3a64b91a3 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/types/span.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" @@ -38,10 +39,9 @@ class ConcreteFunction { virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. - virtual Status GetCallOp(ImmediateOpPtr* out) = 0; + virtual Status GetCallOp(absl::Span inputs, + ImmediateOpPtr* out) = 0; - virtual const std::vector& GetCaptures() - const = 0; virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 8bb15674db0..2b883618c87 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -69,6 +69,7 @@ cc_library( ], deps = [ ":tensorhandle_convertible", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", @@ -77,5 +78,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/eager:context", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc index aa6f0e7205e..f734f9eca66 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include #include +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" @@ -60,16 +62,12 @@ Status TFConcreteFunction::Create( return Status(); } -const std::vector& -TFConcreteFunction::GetCaptures() const { - return captures_; -} - const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { return metadata_; } -Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { +Status TFConcreteFunction::GetCallOp( + absl::Span inputs, ImmediateOpPtr* out) { out->reset(ctx_->CreateOperation()); // In eager mode, TF2 python executes functions by constructing an op with // the name of the functiondef: @@ -81,6 +79,16 @@ Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { // PartitionedCallOp for compatibility with "tooling that assumes functions in // graphs are PartitionedCallOps". TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + + // Adding the user-provided inputs to the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(inputs)); + + absl::Span captures( + reinterpret_cast(captures_.data()), + captures_.size()); + + // Adding the captures of the function. + TF_RETURN_IF_ERROR((*out)->AddInputList(captures)); return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h index 71c8322414d..d38f3546f91 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -58,10 +58,8 @@ class TFConcreteFunction : public ConcreteFunction { std::unique_ptr* out); // This method returns the "Call" Op used to execute the function. - Status GetCallOp(ImmediateOpPtr* out) override; - - const std::vector& GetCaptures() - const override; + Status GetCallOp(absl::Span inputs, + ImmediateOpPtr* out) override; const FunctionMetadata& GetFunctionMetadata() const override; diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 2037c4886de..0d97741d7f0 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" @@ -36,52 +37,8 @@ namespace tensorflow { namespace internal { namespace { -// This returns the size of `tf.nest.flatten(value)`, on values that are -// used in tf.function's input_signatures. 
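In the GetCallOp implementation earlier in this hunk, the rendering dropped the Span and reinterpret_cast template arguments. Assuming `captures_` holds ImmediateExecutionTensorHandle pointers, the captures are presumably re-exposed as base-class handles like so; they are appended after the user inputs so the combined list matches the FunctionDef's expected input count (flattened user inputs plus bound captures), which the validation below checks.

// View the stored captures as AbstractTensorHandle* without copying.
absl::Span<AbstractTensorHandle* const> captures(
    reinterpret_cast<AbstractTensorHandle* const*>(captures_.data()),
    captures_.size());
TF_RETURN_IF_ERROR((*out)->AddInputList(captures));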
-int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) { - // This follows the logic from - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 - switch (value.kind_case()) { - case StructuredValue::kDictValue: { - const DictValue& dict = value.dict_value(); - int size = 0; - for (const auto& field : dict.fields()) { - size += FlattenedSize(field.second, status); - } - return size; - } - case StructuredValue::kTupleValue: { - const TupleValue& tuple = value.tuple_value(); - int size = 0; - for (const StructuredValue& value : tuple.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kListValue: { - const ListValue& list = value.list_value(); - int size = 0; - for (const StructuredValue& value : list.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kTensorSpecValue: { - return 1; - } - case StructuredValue::kNoneValue: { - // Base case: do nothing. - // This arises, for example, as the top-level object of an output - // signature when there are no return values. - return 0; - } - default: { - status->Update(errors::Internal("Unhandled structured value kind ", - value.kind_case())); - return 0; - } - } -} +using StructuredValueDictEntry = + protobuf::MapPair; // Perform some basic sanity checks on SavedConcreteFunction's input and // output signatures with respect to the corresponding FunctionDef's input @@ -111,34 +68,34 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef( // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 const std::string& name = function_def->signature().name(); + const StructuredValue& input_signature = saved_concrete_function.canonicalized_input_signature(); - Status status; - int input_signature_size = FlattenedSize(input_signature, &status); - TF_RETURN_IF_ERROR(status); - if (input_signature_size + saved_concrete_function.bound_inputs_size() != + std::vector input_specs; + TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs)); + if (input_specs.size() + saved_concrete_function.bound_inputs_size() != function_def->signature().input_arg_size()) { return errors::FailedPrecondition( "FunctionDef ", name, " has ", function_def->signature().input_arg_size(), - " inputs, but the SavedConcreteFunction has ", input_signature_size, + " inputs, but the SavedConcreteFunction has ", input_specs.size(), " flattened user inputs and ", saved_concrete_function.bound_inputs_size(), " captured inputs."); } const StructuredValue& output_signature = saved_concrete_function.output_signature(); - int output_signature_size = FlattenedSize(output_signature, &status); - TF_RETURN_IF_ERROR(status); - if (output_signature_size != function_def->signature().output_arg_size()) { + std::vector output_specs; + TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs)); + if (output_specs.size() != function_def->signature().output_arg_size()) { return errors::FailedPrecondition( "FunctionDef ", name, " has ", function_def->signature().output_arg_size(), - " outputs, but the SavedConcreteFunction has ", output_signature_size, + " outputs, but the SavedConcreteFunction has ", output_specs.size(), " flattened outputs."); } - return status; + return Status(); } } // namespace @@ -197,6 +154,62 @@ Status LoadTFConcreteFunction( out); } +Status FlattenSignature(const 
StructuredValue& signature, + std::vector* flattened_specs) { + // This follows the logic from + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 + switch (signature.kind_case()) { + case StructuredValue::kDictValue: { + // Dictionaries must be sorted in order of keys + const DictValue& dict = signature.dict_value(); + std::vector entries; + entries.reserve(dict.fields_size()); + for (const auto& field : dict.fields()) { + entries.push_back(&field); + } + + std::sort(entries.begin(), entries.end(), + [](const StructuredValueDictEntry* x, + const StructuredValueDictEntry* y) { + return x->first < y->first; + }); + + for (const auto& entry : entries) { + TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs)); + } + return Status(); + } + case StructuredValue::kTupleValue: { + const TupleValue& tuple = signature.tuple_value(); + for (const StructuredValue& value : tuple.values()) { + TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs)); + } + return Status(); + } + case StructuredValue::kListValue: { + const ListValue& list = signature.list_value(); + for (const StructuredValue& value : list.values()) { + TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs)); + } + return Status(); + } + case StructuredValue::kTensorSpecValue: { + flattened_specs->push_back(&signature.tensor_spec_value()); + return Status(); + } + case StructuredValue::kNoneValue: { + // Base case: do nothing. + // This arises, for example, as the top-level object of an output + // signature when there are no return values. + return Status(); + } + default: { + return errors::Internal("Unhandled structured value kind ", + signature.kind_case()); + } + } +} + const SavedObject* FindNodeAtPath(StringPiece path, const SavedObjectGraph& object_graph) { const auto& nodes = object_graph.nodes(); diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 57f30afa91b..68bfbe32222 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace internal { @@ -59,10 +60,17 @@ Status LoadTFConcreteFunction( captured_objects, ImmediateExecutionContext* ctx, std::unique_ptr* out); -// Find the SavedObject in `object_graph` at location `path`. `path` must be a -// dot-delimited string of object names relative to the root object. If no -// object is found, returns nullptr. Callers must ensure `object_graph` outlives -// the returned pointer. +// Flattens `signature` into a vector of TensorSpecProto pointers back into +// `signature`. `signature` must outlive flattened_specs. `signature` must also +// be the input or output signature of a SavedConcreteFunction (i.e. "nested +// structures of tensorspecs"). +Status FlattenSignature(const StructuredValue& signature, + std::vector* flattened_specs); + +// Find the SavedObject in `object_graph` at location `path`. `path` must be +// a dot-delimited string of object names relative to the root object. If no +// object is found, returns nullptr. Callers must ensure `object_graph` +// outlives the returned pointer. 
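As an illustration of the FlattenSignature contract above: dictionary fields come back in sorted key order regardless of insertion order. The sketch below is schematic proto population using the struct.pb.h types already included here; field names "a" and "b" are arbitrary.

StructuredValue value;
// Insert "b" first, then "a".
TensorSpecProto* b = (*value.mutable_dict_value()->mutable_fields())["b"]
                         .mutable_tensor_spec_value();
b->set_name("b");
TensorSpecProto* a = (*value.mutable_dict_value()->mutable_fields())["a"]
                         .mutable_tensor_spec_value();
a->set_name("a");

std::vector<const TensorSpecProto*> specs;
TF_CHECK_OK(internal::FlattenSignature(value, &specs));
// specs[0]->name() == "a", specs[1]->name() == "b": sorted by key.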
const SavedObject* FindNodeAtPath(StringPiece path, const SavedObjectGraph& object_graph); diff --git a/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc b/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc new file mode 100644 index 00000000000..9ee495f524a --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_flattening_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace { + +// Validates names, shapes, and dtypes of two tensorspecprotos are equivalent. +bool TensorSpecsAreEqual(const TensorSpecProto& spec, + const std::string& expected_name, + const PartialTensorShape& expected_shape, + DataType expected_dtype) { + return spec.name() == expected_name && + PartialTensorShape(spec.shape()).IsIdenticalTo(expected_shape) && + spec.dtype() == expected_dtype; +} + +// This tests the common case for a tf.function w/o inputs. This ends up +// being serialized as a tuple of an empty tuple + empty dictionary +// (corresponding to the args, kwargs) of the function. +TEST(SignatureFlatteningTest, ZeroArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::ZeroArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 0); +} + +// This tests the common case for a tf.function w/o outputs. This ends up +// being serialized as a "NoneValue". 
+TEST(SignatureFlatteningTest, ZeroRetOutputSignature) { + std::vector flattened; + StructuredValue value = testing::ZeroReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 0); +} + +TEST(SignatureFlatteningTest, SingleArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::SingleArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 1); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "x", + /* expected_shape = */ {1, 10}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); +} + +TEST(SignatureFlatteningTest, SingleReturnOutputSignature) { + std::vector flattened; + StructuredValue value = testing::SingleReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 1); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); +} + +TEST(SignatureFlatteningTest, ThreeArgInputSignature) { + std::vector flattened; + StructuredValue value = testing::ThreeArgInputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 3); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "x", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1], + /* expected_name = */ "y", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[1]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2], + /* expected_name = */ "z", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[2]->DebugString(); +} + +// This test has an exotic outputsignature of tuple of a +// dictionary, tensor +TEST(SignatureFlatteningTest, ThreeReturnOutputSignature) { + std::vector flattened; + StructuredValue value = testing::ThreeReturnOutputSignature(); + TF_EXPECT_OK(internal::FlattenSignature(value, &flattened)); + EXPECT_EQ(flattened.size(), 3); + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0], + /* expected_name = */ "0/a", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[0]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1], + /* expected_name = */ "0/b", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[1]->DebugString(); + + EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2], + /* expected_name = */ "1", + /* expected_shape = */ {1}, + /* expected_dtype = */ DT_FLOAT)) + << "Expected " << flattened[2]->DebugString(); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 225ba1db9f4..0f0102be857 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -15,47 +15,364 @@ limitations under the License. 
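A note on the map aliases introduced near the top of the tf_saved_model_api.cc hunk below: the rendering stripped their template arguments. Based on the accompanying comments and how the maps are used later in the file, plausible definitions are the following; the key types are an assumption (they could be StringPiece rather than std::string).

using FunctionDefMap = std::unordered_map<std::string, const FunctionDef*>;
using NodeAttrMap = std::unordered_map<std::string, const AttrValueMap*>;
using RevivedObjectMap =
    std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
using ConcreteFunctionMap =
    std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;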
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h" +#include #include +#include #include #include +#include "absl/algorithm/container.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_model.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" namespace tensorflow { +// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary +using FunctionDefMap = + std::unordered_map; + +// Maps from a Nodedef's name to its corresponding AttrValues, for a given +// Graphdef +using NodeAttrMap = + std::unordered_map; + +// Maps from Node ID to an "Revived Object" implementing +// "TensorHandleConvertible" +using RevivedObjectMap = + std::unordered_map>; + +// Maps from a functiondef's name to the corresponding "TFConcreteFunction" +using ConcreteFunctionMap = + std::unordered_map>; + +namespace { + +Status ConstantFromSavedConstant( + ImmediateExecutionContext* ctx, + const tensorflow::SavedConstant& saved_constant, + const NodeAttrMap& node_attr_map, std::unique_ptr* output) { + const std::string& const_op_name = saved_constant.operation(); + const auto& node_name_and_attrs = node_attr_map.find(const_op_name); + if (node_name_and_attrs == node_attr_map.end()) { + return errors::FailedPrecondition( + "Unable to find Const operation with name'", const_op_name, + "' in SavedModel graphdef"); + } + const AttrValueMap* attrs = node_name_and_attrs->second; + const auto& attr_name_and_value = attrs->find("value"); + if (attr_name_and_value == attrs->end()) { + return errors::FailedPrecondition("Unable to find Const operation '", + const_op_name, "'s value attribute"); + } + const TensorProto& tensor_proto = attr_name_and_value->second.tensor(); + return internal::TensorProtoToConstant(ctx, 
tensor_proto, output); +} + +// Restores all non-function objects in the SavedModel's object graph. +// This function walks through the metagraph's saved object graph, and +// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and +// SavedResources. These are returned via the `out` parameter. +Status ReviveObjects( + const MetaGraphDef& metagraph, ImmediateExecutionContext* context, + std::unordered_map>* + revived_objects) { + // This is needed to restore "Constant" nodes by looking up their + // "Value" attribute. + NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def()); + + // Iterate through all the saved objects, restoring objects as we go. + // We don't recreate functions until all other objects have been created. + for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) { + const SavedObject& node = metagraph.object_graph_def().nodes(i); + if (node.kind_case() == SavedObject::kVariable) { + std::unique_ptr variable; + TF_RETURN_IF_ERROR( + internal::LoadSavedVariable(context, node.variable(), &variable)); + (*revived_objects)[i] = std::move(variable); + } else if (node.kind_case() == SavedObject::kConstant) { + std::unique_ptr constant; + TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(), + node_attr_map, &constant)); + (*revived_objects)[i] = std::move(constant); + } else if (node.kind_case() == SavedObject::kAsset) { + // TODO(bmzhao): Implement Asset C++ class. This should be just recreating + // the full path to the asset file: + // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396 + // and storing it as a string tensor: + // https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325 + return errors::Unimplemented("SavedAsset loading is not implemented yet"); + } else if (node.kind_case() == SavedObject::kResource) { + // TODO(bmzhao): Figure out how resource loading works and implement it + return errors::Unimplemented( + "SavedResource loading is not implemented yet"); + } + } + return Status(); +} + +Status ReviveFunctions(const MetaGraphDef& metagraph, + const RevivedObjectMap& revived_objects, + ImmediateExecutionContext* context, + ConcreteFunctionMap* restored_functions) { + const FunctionDefMap function_def_map = + internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library()); + + // Iterate through all objects, only examining functions. + for (const SavedObject& node : metagraph.object_graph_def().nodes()) { + if (node.kind_case() == SavedObject::kBareConcreteFunction) { + const std::string& function_name = + node.bare_concrete_function().concrete_function_name(); + + const SavedConcreteFunction& saved_concrete_function = + metagraph.object_graph_def().concrete_functions().at(function_name); + + const FunctionDef* function_def = function_def_map.at(function_name); + std::unique_ptr concrete_function; + TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( + saved_concrete_function, function_def, revived_objects, context, + &concrete_function)); + (*restored_functions)[function_name] = std::move(concrete_function); + } else if (node.kind_case() == SavedObject::kFunction) { + // We only allow loading functions that have an annotated input signature, + // which means there is 1:1 correspondence between tf.function + // <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. 
This is + // the same restriction that MLIR has: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707 + const SavedFunction& saved_function = node.function(); + if (saved_function.concrete_functions_size() != 1) { + return errors::FailedPrecondition( + "Only tf.functions annotated with an input signature are supported " + "by SavedModelAPI. This means that there should only be a single " + "ConcreteFunction per tf.function"); + } + const std::string& function_name = saved_function.concrete_functions(0); + const SavedConcreteFunction& saved_concrete_function = + metagraph.object_graph_def().concrete_functions().at(function_name); + + const FunctionDef* function_def = function_def_map.at(function_name); + + std::unique_ptr concrete_function; + TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction( + saved_concrete_function, function_def, revived_objects, context, + &concrete_function)); + (*restored_functions)[function_name] = std::move(concrete_function); + } + } + return Status(); +} + +const TrackableObjectGraph::TrackableObject::SerializedTensor* +FindSerializedTensorInTrackable( + const TrackableObjectGraph::TrackableObject& trackable_object, + absl::string_view name) { + for (const auto& maybe_serialized_tensor : trackable_object.attributes()) { + if (maybe_serialized_tensor.name() == name) { + return &maybe_serialized_tensor; + } + } + return nullptr; +} + +// This function reads the Checkpoint embedded in the SavedModel, and calls the +// appropriate Restore ops on each of the variables. +// Note(bmzhao): Conceptually, objects that contain checkpointable state +// implement the "_gather_saveables_for_checkpoint" method +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983 +// which returns a dict of string key -> EITHER: +// 1. python callable (taking a checkpoint key) returning SaveableObject OR +// 2. variable (partitioned/resource/reference or otherwise) +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58. +// The string key becomes the "name" attribute of the SerializedTensor proto +// in the TrackableObjectGraph, +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26 +// And the checkpoint_key is a globally unique string derived from this name: +// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241 +// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2 +// ops via their SaveSpec members +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21, +// which contain the "real" checkpoint keys into the TensorBundle SSTable. 
+// They also contain the logic needed to take the restored tensors from +// RestoreV2 and load them back into the "object" they came from via their +// overridden "restore" method: +// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85 +Status RestoreCheckpoint(SavedModelV2Bundle* bundle, + const RevivedObjectMap& revived_objects, + const std::string& directory, + ImmediateExecutionContext* context) { + // TODO(bmzhao): Batch up all the restores into a single restore op per + // device, following logic in MultiDeviceSaver. + TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore( + [&revived_objects, &directory, context, bundle]( + int node, const TrackableObjectGraph::TrackableObject& trackable) { + if (bundle->saved_object_graph().nodes(node).kind_case() != + SavedObject::kVariable) { + // TODO(bmzhao): This requires using the newly added Save/Restore + // functions from + // https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c + LOG(WARNING) << "Restoring non-variable objects has not been " + "implemented yet. (Kind=" + << bundle->saved_object_graph().nodes(node).kind_case() + << ")"; + return Status::OK(); + } + + Variable* variable = + down_cast(revived_objects.at(node).get()); + + // Restore the tensor's value from the checkpoint + const TrackableObjectGraph::TrackableObject::SerializedTensor* + attribute = + FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE"); + if (attribute == nullptr) { + return errors::FailedPrecondition( + "Could not find SerializedTensor with name VARIABLE_VALUE for " + "saved variable"); + } + + const std::string& checkpoint_key = attribute->checkpoint_key(); + std::string variables_path_prefix = + io::JoinPath(directory, kSavedModelVariablesDirectory, + kSavedModelVariablesFilename); + ImmediateTensorHandlePtr restored_output; + TF_RETURN_IF_ERROR(internal::SingleRestore( + context, variables_path_prefix, checkpoint_key, variable->dtype(), + &restored_output)); + + // Assign the restored tensor's value to the variable + return variable->Assign(restored_output.get()); + })); + + return Status(); +} + +} // namespace + Status TFSavedModelAPI::GetFunction(const std::string& function_path, ConcreteFunction** function) { - // TODO(bmzhao): Add support for retrieving a function. - return errors::Unimplemented( - "Retrieving functions is unimplemented currently"); + const SavedObject* object = + internal::FindNodeAtPath(function_path, bundle_.saved_object_graph()); + if (object == nullptr) { + return errors::NotFound("No saved object found at path ", function_path); + } + + if (object->kind_case() == SavedObject::kBareConcreteFunction) { + *function = + concrete_functions_ + .at(object->bare_concrete_function().concrete_function_name()) + .get(); + } else if (object->kind_case() == SavedObject::kFunction) { + *function = + concrete_functions_.at(object->function().concrete_functions(0)).get(); + } else { + return errors::InvalidArgument(function_path, + " is not a path to a Function."); + } + + return Status(); } Status TFSavedModelAPI::GetSignatureDefFunction( const std::string& signature_def_key, ConcreteFunction** function) { // TODO(bmzhao): Add support for retrieving a signaturedef function. 
return errors::Unimplemented( - "Retrieving functions is unimplemented currently"); + "Retrieving SignatureDef functions is unimplemented currently"); } std::vector TFSavedModelAPI::ListFunctions() { std::vector result; - result.reserve(functions_.size()); - for (ConcreteFunction& function : functions_) { - result.push_back(&function); + result.reserve(concrete_functions_.size()); + for (auto& index_and_function : concrete_functions_) { + result.push_back(index_and_function.second.get()); } return result; } +TFSavedModelAPI::TFSavedModelAPI( + const std::string& directory, SavedModelV2Bundle bundle, + std::unordered_map> + revived_objects, + std::unordered_map> + concrete_functions) + : directory_(directory), + bundle_(std::move(bundle)), + revived_objects_(std::move(revived_objects)), + concrete_functions_(std::move(concrete_functions)) {} + Status TFSavedModelAPI::Load( const std::string& directory, const absl::optional>& tags, ImmediateExecutionContext* context, std::unique_ptr* out) { - // TODO(bmzhao): Add support for loading a TFSavedModelImpl. - return errors::Unimplemented( - "TFSavedModelAPIImpl loading is unimplemented currently"); + // TODO(bmzhao): Add support for loading a TF1 SavedModel. + if (tags) { + return errors::Unimplemented( + "Loading saved models with explicit tags will be supported in the " + "future"); + } + + SavedModelV2Bundle bundle; + TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle)); + + // TODO(bmzhao): Mangle loaded function names so that different + // models loaded in the same runtime Context don't clobber eachother. + // This occurs in python here: + // https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454 + + RevivedObjectMap revived_objects; + TF_RETURN_IF_ERROR( + ReviveObjects(bundle.meta_graph_def(), context, &revived_objects)); + + // TODO(bmzhao): When we later add support for loading resources, we need to + // handle the case where materializing a function's captures requires invoking + // other functions. This occurs when retrieving the resource handle for a + // TrackableResource: + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240 + // https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233 + // This requires restoring functions in a topological sort order by capture + // dependencies. + ConcreteFunctionMap function_map; + TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects, + context, &function_map)); + + TF_RETURN_IF_ERROR( + RestoreCheckpoint(&bundle, revived_objects, directory, context)); + + out->reset(new TFSavedModelAPI(directory, std::move(bundle), + std::move(revived_objects), + std::move(function_map))); + return Status(); } } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index cc631a9f3ae..fc8e738e86f 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -16,14 +16,19 @@ limitations under the License. 
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_ +#include #include +#include #include #include #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI { ~TFSavedModelAPI() override = default; private: - TFSavedModelAPI() = default; - std::vector functions_; + TFSavedModelAPI( + const std::string& directory, SavedModelV2Bundle bundle, + std::unordered_map> + revived_objects, + std::unordered_map> + concrete_functions); + + std::string directory_; + SavedModelV2Bundle bundle_; + std::unordered_map> + revived_objects_; + std::unordered_map> + concrete_functions_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index b22718dfd04..323298c5fc1 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -38,16 +38,17 @@ cc_library( ":concrete_function_type", ":function_metadata", ":function_metadata_type", - ":tensorhandle_list", - ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:function_metadata", "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", ], ) @@ -164,38 +165,6 @@ cc_library( ], ) -cc_library( - name = "tensorhandle_list", - srcs = [ - "tensorhandle_list.cc", - ], - hdrs = [ - "//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h", - ], - copts = tf_copts(), - visibility = [ - "//tensorflow/c/experimental/saved_model/public:__pkg__", - ], - deps = [ - ":tensorhandle_list_type", - "//tensorflow/c:c_api_macros", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/c/eager:tfe_tensorhandle_internal", - ], -) - -cc_library( - name = "tensorhandle_list_type", - hdrs = [ - "tensorhandle_list_type.h", - ], - deps = [ - "//tensorflow/c:conversion_macros", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - ], -) - tf_cc_test( name = "saved_model_api_test", size = "small", @@ -213,7 +182,6 @@ tf_cc_test( "//tensorflow/c/eager:c_api_test_util", "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", - "//tensorflow/c/experimental/saved_model/public:tensorhandle_list", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 12d49212a88..65c6eca5623 100644 --- 
a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,13 +15,15 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" -#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/platform/status.h" @@ -32,15 +34,18 @@ TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata(TF_ConcreteFunction* func) { &tensorflow::unwrap(func)->GetFunctionMetadata())); } -const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( - TF_ConcreteFunction* func) { - return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); -} - TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, + TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { - tensorflow::ImmediateOpPtr call_op(nullptr); - status->status = tensorflow::unwrap(func)->GetCallOp(&call_op); + tensorflow::ImmediateOpPtr call_op; + absl::Span input_span( + reinterpret_cast( + tensorflow::unwrap(inputs)), + static_cast(num_inputs)); + status->status = tensorflow::unwrap(func)->GetCallOp(input_span, &call_op); + if (!status->status.ok()) { + return nullptr; + } return tensorflow::wrap(call_op.release()); } diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc index aa0b00ab847..e58b232f9c9 100644 --- a/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api_test.cc @@ -16,10 +16,14 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" #include +#include #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" @@ -92,12 +96,42 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) { TF_SavedModel* saved_model = TF_LoadSavedModel(model_dir.c_str(), ctx, status); - // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. - // That unblocks writing other tests that require a TF_SavedModel*, - // like loading a ConcreteFunction. This test at least checks that the - // C API builds and can be minimally run. 
- EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + TF_ConcreteFunction* compute_fn = + TF_GetSavedModelConcreteFunction(saved_model, "compute", status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + std::vector compute_fn_inputs; + TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f); + TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f); + compute_fn_inputs.push_back(input_a); + compute_fn_inputs.push_back(input_b); + + TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp( + compute_fn, compute_fn_inputs.data(), compute_fn_inputs.size(), status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + // TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many + // inputs + outputs a function has. + TFE_TensorHandle* compute_fn_outputs[1] = {nullptr}; + int num_retvals = 1; + + TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status); + EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + + EXPECT_EQ(TF_NumDims(result), 0); + float output_value = *static_cast(TF_TensorData(result)); + // (1 + 2) * (2 + 1) / 3 + 5 should be 8 + EXPECT_FLOAT_EQ(output_value, 8.0); + + TF_DeleteTensor(result); + TFE_DeleteTensorHandle(compute_fn_outputs[0]); + TFE_DeleteTensorHandle(input_a); + TFE_DeleteTensorHandle(input_b); + TFE_DeleteOp(compute_fn_op); TF_DeleteSavedModel(saved_model); TF_DeleteStatus(status); TFE_DeleteContext(ctx); diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index 0cfa0a2c005..af65e05e7f6 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -24,7 +24,6 @@ exports_files( "concrete_function_list.h", "function_metadata.h", "saved_model_api.h", - "tensorhandle_list.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -40,7 +39,6 @@ cc_library( ":concrete_function_list", ":function_metadata", ":saved_model_api", - ":tensorhandle_list", ], ) @@ -63,8 +61,3 @@ alias( name = "saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", ) - -alias( - name = "tensorhandle_list", - actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list", -) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index aae95a5477c..30f533f140a 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" -#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 944ddecea16..ee5292294d6 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" #ifdef __cplusplus extern "C" { @@ -35,13 +34,15 @@ typedef struct TF_ConcreteFunction TF_ConcreteFunction; TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( TF_ConcreteFunction* func); -// Returns a list of TensorHandles implicitly captured by this function. -TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( - TF_ConcreteFunction* func); - -// Returns a TFE_Op suitable for executing this function. +// Returns a TFE_Op suitable for executing this function. Caller must provide +// all function inputs in `inputs`, and must not add any additional inputs on +// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList). +// The caller is responsible for deleting the returned TFE_Op. If op +// construction fails, `status` will be non-OK and the returned pointer will be +// null. TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( - TF_ConcreteFunction* func, TF_Status* status); + TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status); #ifdef __cplusplus } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h b/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h deleted file mode 100644 index a1e88db3474..00000000000 --- a/tensorflow/c/experimental/saved_model/public/tensorhandle_list.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ - -#include - -#include "tensorflow/c/c_api_macros.h" -#include "tensorflow/c/eager/c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// An opaque type that is acts like a list of TF_ConcreteFunction pointers. 
-typedef struct TF_TensorHandleList TF_TensorHandleList; - -// Returns the size of `list`. -TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize( - const TF_TensorHandleList* list); - -// Returns the `i`th TFE_TensorHandle in the list. -TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet( - const TF_TensorHandleList* list, int i); - -#ifdef __cplusplus -} // end extern "C" -#endif // __cplusplus - -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_ diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 3021a38e888..20a6c5117cf 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -97,6 +97,11 @@ void TF_KernelBuilder_HostMemory(TF_KernelBuilder* kernel_builder, kernel_builder->cc_builder->HostMemory(arg_name); } +void TF_KernelBuilder_Priority(TF_KernelBuilder* kernel_builder, + int32_t priority_number) { + kernel_builder->cc_builder->Priority(priority_number); +} + namespace tensorflow { namespace { @@ -234,6 +239,14 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) { DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType) DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t) +TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) { + auto* cc_ctx = reinterpret_cast(ctx); + TF_StringView string_view_of_name; + string_view_of_name.data = cc_ctx->def().name().data(); + string_view_of_name.len = cc_ctx->def().name().length(); + return string_view_of_name; +} + TF_DataType TF_ExpectedOutputDataType(TF_OpKernelContext* ctx, int i) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); return static_cast(cc_ctx->expected_output_dtype(i)); @@ -266,4 +279,4 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, return nullptr; } return tf_tensor; -} +} \ No newline at end of file diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index 084717c1d9e..c7138a39c73 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/c/c_api.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" @@ -107,6 +108,10 @@ TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint( TF_CAPI_EXPORT extern void TF_KernelBuilder_HostMemory( TF_KernelBuilder* kernel_builder, const char* arg_name); +// Specify a priority number for this kernel. +TF_CAPI_EXPORT extern void TF_KernelBuilder_Priority( + TF_KernelBuilder* kernel_builder, int32_t priority_number); + // Register the given kernel builder with the TensorFlow runtime. If // registration fails, the given status will be populated. // @@ -180,6 +185,10 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32( TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val, TF_Status* status); +// Returns the unique operation name for this OpKernel. +TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName( + TF_OpKernelConstruction* ctx); + // Allocates Tensor for output at given index. Caller takes ownership of // returned TF_Tensor and should deallocate it using TF_DeleteTensor(tensor). 
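The two kernels.h additions above are most naturally used on a kernel's registration path: the priority influences which kernel is selected when several registered kernels match, and the name accessor lets the create function see which NodeDef it is being instantiated for. A minimal sketch with illustrative op, device, and priority values:

// Sketch: a C-API kernel that reads its NodeDef name at construction time and
// registers with an explicit priority. "FooOp", "FakeDevice", and the priority
// value are illustrative.
void* MyCreate(TF_OpKernelConstruction* ctx) {
  TF_StringView name = TF_OpKernelConstruction_GetName(ctx);
  std::string node_name(name.data, name.len);  // e.g. "SomeNodeName"
  return nullptr;
}
void MyCompute(void* kernel, TF_OpKernelContext* ctx) {}

void RegisterFooKernel() {
  TF_KernelBuilder* builder = TF_NewKernelBuilder(
      "FooOp", "FakeDevice", &MyCreate, &MyCompute, /*delete_func=*/nullptr);
  TF_KernelBuilder_Priority(builder, 1);  // relative priority among matching kernels
  TF_Status* status = TF_NewStatus();
  TF_RegisterKernelBuilder("FooOpKernel", builder, status);
  TF_DeleteStatus(status);
}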
// diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index 770352c62c1..008d2ee2d67 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -24,6 +24,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "summary_op", + prefix = "summary_op", + deps = [ + "//tensorflow/c:kernels", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", + "//tensorflow/c/kernels:tensor_shape_utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", + ], +) + tf_gen_op_libs( op_lib_names = ["bitcast"], deps = [ @@ -35,6 +50,15 @@ tf_gen_op_libs( ], ) +tf_gen_op_libs( + op_lib_names = ["summary"], + deps = [ + "//tensorflow/c:ops", + "//tensorflow/c:tf_status", + "//tensorflow/core:lib", + ], +) + tf_cc_test( name = "bitcast_op_test", srcs = ["bitcast_op_test.cc"], @@ -48,6 +72,62 @@ tf_cc_test( ], ) +tf_cc_test( + name = "summary_op_test", + srcs = ["summary_op_test.cc"], + deps = [ + ":summary_op", + "//tensorflow/c:kernels", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "summary_op_benchmark_test", + size = "small", + srcs = ["summary_op_benchmark_test.cc"], + deps = [ + ":summary_op", + "//tensorflow/c:kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "tensor_shape_utils", + srcs = ["tensor_shape_utils.cc"], + hdrs = ["tensor_shape_utils.h"], + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/c:tf_tensor", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tensor_shape_utils_test", + srcs = ["tensor_shape_utils_test.cc"], + deps = [ + ":tensor_shape_utils", + "//tensorflow/c:tf_tensor_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # Changes to the Android srcs here should be replicated in # tensorflow/contrib/makefile/tf_op_files.txt. # @@ -59,11 +139,17 @@ filegroup( name = "android_all_op_kernels", srcs = [ "bitcast_op.cc", + "summary_op.cc", + "tensor_shape_utils.cc", + "tensor_shape_utils.h", ], ) # LINT.ThenChange(//tensorflow/contrib/makefile/tf_op_files.txt) filegroup( name = "android_all_ops", - srcs = ["ops/bitcast.cc"], + srcs = [ + "ops/bitcast.cc", + "ops/summary.cc", + ], ) diff --git a/tensorflow/c/kernels/ops/bitcast.cc b/tensorflow/c/kernels/ops/bitcast.cc index 3ba56411c38..0bc9fe86f10 100644 --- a/tensorflow/c/kernels/ops/bitcast.cc +++ b/tensorflow/c/kernels/ops/bitcast.cc @@ -22,8 +22,19 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" static void ComputeNewShape(TF_ShapeInferenceContext* ctx, - TF_ShapeHandle* shape, size_t input_type_size, - size_t output_type_size, TF_Status* status) { + TF_ShapeHandle* shape, TF_DataType input_type, + TF_DataType output_type, TF_Status* status) { + size_t input_type_size = TF_DataTypeSize(input_type); + size_t output_type_size = TF_DataTypeSize(output_type); + + if (input_type_size == 0 || output_type_size == 0) { + std::ostringstream err; + err << "Cannot bitcast type " << input_type << " to " << output_type + << " because one of the type sizes is zero"; + TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); + return; + } + TF_SetStatus(status, TF_OK, ""); if (input_type_size < output_type_size) { TF_ShapeInferenceContextWithRankAtLeast(ctx, shape, 1, shape, status); @@ -37,9 +48,9 @@ static void ComputeNewShape(TF_ShapeInferenceContext* ctx, TF_ShapeInferenceContextSubshape(ctx, shape, 0, -1, shape, status); } else { std::ostringstream err; - err << "Cannot bitcast due to shape. " - << TF_DimensionHandleValue(last_dim) << " does not match " - << divisor_val; + err << "Cannot bitcast from " << input_type << " to " << output_type + << " due to shape. " << TF_DimensionHandleValue(last_dim) + << " does not match " << divisor_val; TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); } TF_DeleteDimensionHandle(last_dim); @@ -78,23 +89,8 @@ static void bitcast_shape_inference_fn(TF_ShapeInferenceContext* ctx, TF_ShapeInferenceContext_GetAttrType(ctx, "type", &output_type, status); } - size_t input_type_size; - size_t output_type_size; - if (TF_GetCode(status) == TF_OK) { - input_type_size = TF_DataTypeSize(input_type); - output_type_size = TF_DataTypeSize(output_type); - - if (input_type_size == 0 || output_type_size == 0) { - std::ostringstream err; - err << "Cannot bitcast type " << input_type << " to " << output_type - << " because one of the type sizes is zero"; - TF_SetStatus(status, TF_INVALID_ARGUMENT, err.str().c_str()); - } - } - - if (TF_GetCode(status) == TF_OK) { - ComputeNewShape(ctx, result, input_type_size, output_type_size, status); + ComputeNewShape(ctx, result, input_type, output_type, status); } if (TF_GetCode(status) == TF_OK) { diff --git a/tensorflow/c/kernels/ops/summary.cc b/tensorflow/c/kernels/ops/summary.cc new file mode 100644 index 00000000000..b6b37f6b5b4 --- /dev/null +++ b/tensorflow/c/kernels/ops/summary.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/ops.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" + +static void scalar_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx); + TF_ShapeInferenceContextSetOutput(ctx, 0, result, status); + TF_DeleteShapeHandle(result); +} + +void Register_ScalarSummaryOp() { + TF_Status* status = TF_NewStatus(); + + TF_OpDefinitionBuilder* op_builder = + TF_NewOpDefinitionBuilder("ScalarSummary"); + TF_OpDefinitionBuilderAddInput(op_builder, "tags: string"); + TF_OpDefinitionBuilderAddInput(op_builder, "values: T"); + TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string"); + TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype"); + TF_OpDefinitionBuilderSetShapeInferenceFunction( + op_builder, &scalar_summary_shape_inference_fn); + + TF_RegisterOpDefinition(op_builder, status); + CHECK_EQ(TF_GetCode(status), TF_OK) + << "ScalarSummary op registration failed: " << TF_Message(status); + TF_DeleteStatus(status); +} + +TF_ATTRIBUTE_UNUSED static bool SummaryScalarOpRegistered = []() { + if (SHOULD_REGISTER_OP("ScalarSummary")) { + Register_ScalarSummaryOp(); + } + return true; +}(); diff --git a/tensorflow/c/kernels/summary_op.cc b/tensorflow/c/kernels/summary_op.cc new file mode 100644 index 00000000000..bd528da4165 --- /dev/null +++ b/tensorflow/c/kernels/summary_op.cc @@ -0,0 +1,172 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/c/kernels.h" +#include "tensorflow/c/kernels/tensor_shape_utils.h" +#include "tensorflow/c/tf_status.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/selective_registration.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" + +namespace { + +// Struct that stores the status and TF_Tensor inputs to the opkernel. +// Used to delete tensor and status in its destructor upon kernel return. 
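For context on how the kernel defined in this file is reached, a graph node that feeds it is wired the same way the benchmark later in this change does it. A minimal sketch, assuming a tensorflow::Graph* `g` and same-shaped DT_STRING `tags` and DT_FLOAT `values` tensors:

// Sketch: constructing a ScalarSummary node (mirrors summary_op_benchmark_test.cc).
Node* summary_node = nullptr;
TF_CHECK_OK(NodeBuilder(g->NewName("scalar_summary"), "ScalarSummary")
                .Input(test::graph::Constant(g, tags))
                .Input(test::graph::Constant(g, values))
                .Attr("T", DT_FLOAT)
                .Finalize(g, &summary_node));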
+struct Params { + TF_Tensor* tags; + TF_Tensor* values; + TF_Status* status; + explicit Params(TF_OpKernelContext* ctx) + : tags(nullptr), values(nullptr), status(nullptr) { + status = TF_NewStatus(); + TF_GetInput(ctx, 0, &tags, status); + if (TF_GetCode(status) == TF_OK) { + TF_GetInput(ctx, 1, &values, status); + } + } + ~Params() { + TF_DeleteStatus(status); + TF_DeleteTensor(tags); + TF_DeleteTensor(values); + } +}; + +// dummy functions used for kernel registration +void* ScalarSummaryOp_Create(TF_OpKernelConstruction* ctx) { return nullptr; } + +void ScalarSummaryOp_Delete(void* kernel) {} + +// Helper functions for compute method +bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2); +// Returns a string representation of a single tag or empty string if there +// are multiple tags +std::string SingleTag(TF_Tensor* tags); + +template +void ScalarSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) { + Params params(ctx); + if (TF_GetCode(params.status) != TF_OK) { + TF_OpKernelContext_Failure(ctx, params.status); + return; + } + if (!IsSameSize(params.tags, params.values)) { + std::ostringstream err; + err << "tags and values are not the same shape: " + << tensorflow::ShapeDebugString(params.tags) + << " != " << tensorflow::ShapeDebugString(params.values) + << SingleTag(params.tags); + TF_SetStatus(params.status, TF_INVALID_ARGUMENT, err.str().c_str()); + TF_OpKernelContext_Failure(ctx, params.status); + return; + } + // Convert tags and values tensor to array to access elements by index + tensorflow::Summary s; + auto tags_array = + static_cast(TF_TensorData(params.tags)); + auto values_array = static_cast(TF_TensorData(params.values)); + // Copy tags and values into summary protobuf + for (int i = 0; i < TF_TensorElementCount(params.tags); ++i) { + tensorflow::Summary::Value* v = s.add_value(); + const tensorflow::tstring& Ttags_i = tags_array[i]; + v->set_tag(Ttags_i.data(), Ttags_i.size()); + v->set_simple_value(static_cast(values_array[i])); + } + TF_Tensor* summary_tensor = + TF_AllocateOutput(ctx, 0, TF_ExpectedOutputDataType(ctx, 0), nullptr, 0, + sizeof(tensorflow::tstring), params.status); + if (TF_GetCode(params.status) != TF_OK) { + TF_DeleteTensor(summary_tensor); + TF_OpKernelContext_Failure(ctx, params.status); + return; + } + tensorflow::tstring* output_tstring = + reinterpret_cast(TF_TensorData(summary_tensor)); + CHECK(SerializeToTString(s, output_tstring)); + TF_DeleteTensor(summary_tensor); +} + +bool IsSameSize(TF_Tensor* tensor1, TF_Tensor* tensor2) { + if (TF_NumDims(tensor1) != TF_NumDims(tensor2)) { + return false; + } + for (int d = 0; d < TF_NumDims(tensor1); d++) { + if (TF_Dim(tensor1, d) != TF_Dim(tensor2, d)) { + return false; + } + } + return true; +} + +std::string SingleTag(TF_Tensor* tags) { + if (TF_TensorElementCount(tags) == 1) { + const char* single_tag = + static_cast(TF_TensorData(tags))->c_str(); + return tensorflow::strings::StrCat(" (tag '", single_tag, "')"); + } else { + return ""; + } +} + +template +void RegisterScalarSummaryOpKernel() { + TF_Status* status = TF_NewStatus(); + { + auto* builder = TF_NewKernelBuilder( + "ScalarSummary", tensorflow::DEVICE_CPU, &ScalarSummaryOp_Create, + &ScalarSummaryOp_Compute, &ScalarSummaryOp_Delete); + TF_KernelBuilder_TypeConstraint( + builder, "T", + static_cast(tensorflow::DataTypeToEnum::v()), status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint"; + TF_RegisterKernelBuilder("ScalarSummary", builder, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) + << 
"Error while registering Scalar Summmary kernel"; + } + TF_DeleteStatus(status); +} + +// A dummy static variable initialized by a lambda whose side-effect is to +// register the ScalarSummary kernel. +TF_ATTRIBUTE_UNUSED bool IsScalarSummaryOpKernelRegistered = []() { + if (SHOULD_REGISTER_OP_KERNEL("ScalarSummary")) { + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + RegisterScalarSummaryOpKernel(); + } + return true; +}(); +} // namespace diff --git a/tensorflow/c/kernels/summary_op_benchmark_test.cc b/tensorflow/c/kernels/summary_op_benchmark_test.cc index 9a68d5ddec1..7c1ab1f7103 100644 --- a/tensorflow/c/kernels/summary_op_benchmark_test.cc +++ b/tensorflow/c/kernels/summary_op_benchmark_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/graph/node_builder.h" @@ -20,19 +22,20 @@ limitations under the License. #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/framework/tensor_shape.h" -namespace tensorflow { +namespace tensorflow { +namespace { -static Graph* BM_ScalarSummaryOp(TensorShape shape, const char* tag, - float value) { +Graph* BM_ScalarSummaryOp(TensorShape shape, std::string tag, + float value) { Graph* g = new Graph(OpRegistry::Global()); Tensor tags(DT_STRING, shape); Tensor values(DT_FLOAT, shape); for (int i = 0; i < tags.NumElements(); ++i){ - tags.flat()(i) = tag; - values.flat()(i) = value; + tags.flat()(i) = tag; + values.flat()(i) = value; } Node* ret; - TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "SummaryScalar") + TF_CHECK_OK(NodeBuilder(g->NewName("dummy"), "ScalarSummary") .Input(test::graph::Constant(g, tags)) .Input(test::graph::Constant(g, values)) .Attr("T", DT_FLOAT) @@ -42,23 +45,27 @@ static Graph* BM_ScalarSummaryOp(TensorShape shape, const char* tag, // Macro used to parse initializer list for tensorshape #define DIMARGS(...) 
{__VA_ARGS__} -// Random parameters for testing -constexpr char longTagParam = "LONGTAG____________________________"; +// // Random parameters for testing +constexpr char longTagParam[] = "LONGTAG____________________________"; constexpr float largeValueParam = 2352352.2623433; -#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ - static void BM_ScalarSummary_##name##_##device(int iters) { \ - TensorShape tensorshape(DIMARGS(dims)); \ - test::Benchmark(#device, BM_ScalarSummaryOp( \ - tensorshape, #tag, value)).Run(iters); \ - } \ - BENCHMARK(BM_ScalarSummary_##name##_##device); +#define BM_ScalarSummaryDev(device, dims, name, tag, value) \ + void BM_ScalarSummary##name##device(int iters) { \ + testing::StopTiming(); \ + TensorShape tensorshape(DIMARGS dims); \ + auto g = BM_ScalarSummaryOp(tensorshape, #tag, value); \ + testing::StartTiming(); \ + test::Benchmark("cpu", g).Run(iters); \ + } \ + BENCHMARK(BM_ScalarSummary##name##device); -BM_ScalarSummaryDev(cpu, (5, 10, 100), Base, tag, 5.2); +BM_ScalarSummaryDev(Cpu, (5, 10, 100), Base, Tag, 5.2); // Benchmark for large shapes -BM_ScalarSummaryDev(cpu, (500, 1000, 10000), Large_Shape, tag, 5.2); +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeShape, Tag, 5.2); // Benchmark for large tag tstring -BM_ScalarSummaryDev(cpu, (5, 10, 100), Long_Tag, longTagParam, 5.2); +BM_ScalarSummaryDev(Cpu, (5, 10, 100), LongTag, longTagParam, 5.2); // Benchmark for large values -BM_ScalarSummaryDev(cpu, (500, 1000, 10000), Large_Value, tag, largeValueParam); -} // namespace tensorflow \ No newline at end of file +BM_ScalarSummaryDev(Cpu, (500, 100, 100), LargeValue, Tag, largeValueParam); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/kernels/summary_op_test.cc b/tensorflow/c/kernels/summary_op_test.cc new file mode 100644 index 00000000000..68c8deb5eab --- /dev/null +++ b/tensorflow/c/kernels/summary_op_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/kernels.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" + +namespace tensorflow { +namespace { + +class DummyDevice : public DeviceBase { + public: + explicit DummyDevice(Env* env) : DeviceBase(env) {} + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } +}; + +// Helper for comparing ouput and expected output +void ExpectSummaryMatches(const Summary& actual, const string& expected_str) { + Summary expected; + ASSERT_TRUE(protobuf::TextFormat::ParseFromString(expected_str, &expected)); + EXPECT_EQ(expected.DebugString(), actual.DebugString()); +} + +void TestScalarSummaryOp(Tensor* tags, Tensor* values, string expected_output, + error::Code expected_code) { + // Initialize node used to fetch OpKernel + Status status; + NodeDef def; + def.set_op("ScalarSummary"); + + def.set_device(DEVICE_CPU); + + AttrValue valuesTypeAttr; + SetAttrValue(values->dtype(), &valuesTypeAttr); + (*def.mutable_attr())["T"] = valuesTypeAttr; + + def.add_input(strings::StrCat("input1: ", DataTypeString(tags->dtype()))); + def.add_input(strings::StrCat("input2: ", DataTypeString(values->dtype()))); + + std::unique_ptr kernel = + CreateOpKernel(DeviceType(DEVICE_CPU), nullptr, nullptr, def, 1, &status); + ASSERT_TRUE(status.ok()) << status.ToString(); + OpKernelContext::Params params; + DummyDevice dummy_device(nullptr); + params.device = &dummy_device; + params.op_kernel = kernel.get(); + AllocatorAttributes alloc_attrs; + params.output_attr_array = &alloc_attrs; + gtl::InlinedVector inputs; + inputs.emplace_back(tags); + inputs.emplace_back(values); + params.inputs = &inputs; + OpKernelContext ctx(¶ms, 1); + kernel->Compute(&ctx); + ASSERT_EQ(expected_code, ctx.status().code()); + if (expected_code == error::OK) { + Summary summary; + ASSERT_TRUE(ParseProtoUnlimited( + &summary, ctx.mutable_output(0)->scalar()())); + ExpectSummaryMatches(summary, expected_output); + } else { + EXPECT_TRUE(absl::StrContains(ctx.status().ToString(), expected_output)) + << ctx.status(); + } +} + +TEST(ScalarSummaryOpTest, SimpleFloat) { + int vectorSize = 3; + Tensor tags(DT_STRING, {vectorSize}); + Tensor values(DT_FLOAT, {vectorSize}); + tags.vec()(0) = "tag1"; + tags.vec()(1) = "tag2"; + tags.vec()(2) = "tag3"; + values.vec()(0) = 1.0f; + values.vec()(1) = -0.73f; + values.vec()(2) = 10000.0f; + TestScalarSummaryOp(&tags, &values, R"( + value { tag: 'tag1' simple_value: 1.0 } + value { tag: 'tag2' simple_value: -0.73} + value { tag: 'tag3' simple_value: 10000.0})", + error::OK); +} + +TEST(ScalarSummaryOpTest, SimpleDouble) { + int vectorSize = 3; + 
Tensor tags(DT_STRING, {vectorSize}); + Tensor values(DT_DOUBLE, {vectorSize}); + tags.vec()(0) = "tag1"; + tags.vec()(1) = "tag2"; + tags.vec()(2) = "tag3"; + values.vec()(0) = 1.0; + values.vec()(1) = -0.73; + values.vec()(2) = 10000.0; + TestScalarSummaryOp(&tags, &values, R"( + value { tag: 'tag1' simple_value: 1.0 } + value { tag: 'tag2' simple_value: -0.73} + value { tag: 'tag3' simple_value: 10000.0})", + error::OK); +} + +TEST(ScalarSummaryOpTest, SimpleHalf) { + int vectorSize = 3; + Tensor tags(DT_STRING, {vectorSize}); + Tensor values(DT_HALF, {vectorSize}); + tags.vec()(0) = "tag1"; + tags.vec()(1) = "tag2"; + tags.vec()(2) = "tag3"; + values.vec()(0) = Eigen::half(1.0); + values.vec()(1) = Eigen::half(-2.0); + values.vec()(2) = Eigen::half(10000.0); + TestScalarSummaryOp(&tags, &values, R"( + value { tag: 'tag1' simple_value: 1.0 } + value { tag: 'tag2' simple_value: -2.0} + value { tag: 'tag3' simple_value: 10000.0})", + error::OK); +} + +TEST(ScalarSummaryOpTest, Error_WrongDimsTags) { + Tensor tags(DT_STRING, {2, 1}); + Tensor values(DT_FLOAT, {2}); + tags.matrix()(0, 0) = "tag1"; + tags.matrix()(1, 0) = "tag2"; + values.vec()(0) = 1.0f; + values.vec()(1) = -2.0f; + TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape", + error::INVALID_ARGUMENT); +} + +TEST(ScalarSummaryOpTest, Error_WrongValuesTags) { + Tensor tags(DT_STRING, {2}); + Tensor values(DT_FLOAT, {2, 1}); + tags.vec()(0) = "tag1"; + tags.vec()(1) = "tag2"; + values.matrix()(0, 0) = 1.0f; + values.matrix()(1, 0) = -2.0f; + TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape", + error::INVALID_ARGUMENT); +} + +TEST(ScalarSummaryOpTest, Error_WrongWithSingleTag) { + Tensor tags(DT_STRING, {1}); + Tensor values(DT_FLOAT, {2, 1}); + tags.vec()(0) = "tag1"; + values.matrix()(0, 0) = 1.0f; + values.matrix()(1, 0) = -2.0f; + TestScalarSummaryOp(&tags, &values, "tags and values are not the same shape", + error::INVALID_ARGUMENT); +} + +TEST(ScalarSummaryOpTest, IsRegistered) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("ScalarSummary", ®)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/kernels/tensor_shape_utils.cc b/tensorflow/c/kernels/tensor_shape_utils.cc new file mode 100644 index 00000000000..967330ccb93 --- /dev/null +++ b/tensorflow/c/kernels/tensor_shape_utils.cc @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/kernels/tensor_shape_utils.h" + +#include + +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/strcat.h" + +namespace tensorflow { + +std::string ShapeDebugString(TF_Tensor* tensor) { + // A TF_Tensor cannot have an unknown rank. 
+ CHECK_GE(TF_NumDims(tensor), 0); + tensorflow::string s = "["; + for (int i = 0; i < TF_NumDims(tensor); ++i) { + if (i > 0) tensorflow::strings::StrAppend(&s, ","); + int64_t dim = TF_Dim(tensor, i); + // A TF_Tensor cannot have an unknown dimension. + CHECK_GE(dim, 0); + tensorflow::strings::StrAppend(&s, dim); + } + tensorflow::strings::StrAppend(&s, "]"); + return s; +} +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/kernels/tensor_shape_utils.h similarity index 51% rename from tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h rename to tensorflow/c/kernels/tensor_shape_utils.h index 566417df025..bfe51bc1a2a 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h +++ b/tensorflow/c/kernels/tensor_shape_utils.h @@ -13,25 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +// This file contains shape utilities to be used by kernels and is not part of +// the C API. As such, it is subject to change at any time. -#include +#ifndef TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_ +#define TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_ -#include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include -// Internal structures used by the SavedModel C API. These are likely to -// change and should not be depended on. - -typedef struct TF_TensorHandleList TF_TensorHandleList; +#include "tensorflow/c/tf_tensor.h" namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS( - std::vector, - TF_TensorHandleList) +// The following are utils for the shape of a TF_Tensor type. +// These functions may later be subsumed by the methods for a +// TF_TensorShape type. + +// Returns a string representation of the TF_Tensor shape. +std::string ShapeDebugString(TF_Tensor* tensor); } // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_CONCRETE_FUNCTION_LIST_TYPE_H_ +#endif // TENSORFLOW_C_TENSOR_SHAPE_UTILS_H_ diff --git a/tensorflow/c/kernels/tensor_shape_utils_test.cc b/tensorflow/c/kernels/tensor_shape_utils_test.cc new file mode 100644 index 00000000000..783105f3ad7 --- /dev/null +++ b/tensorflow/c/kernels/tensor_shape_utils_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/kernels/tensor_shape_utils.h" + +#include "tensorflow/c/tf_tensor_internal.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +// A wrapper that will automatically delete the allocated TF_Tensor +// once out of scope. +struct TF_TensorWrapper { + TF_Tensor* tf_tensor; + explicit TF_TensorWrapper(TF_Tensor* tensor) { tf_tensor = tensor; } + ~TF_TensorWrapper() { TF_DeleteTensor(tf_tensor); } +}; + +void TestShapeMatch(TensorShape shape) { + Tensor tensor(DT_FLOAT, shape); + Status status; + TF_Tensor* tf_tensor = TF_TensorFromTensor(tensor, &status); + TF_TensorWrapper tensor_wrapper = TF_TensorWrapper(tf_tensor); + ASSERT_TRUE(status.ok()) << status.ToString(); + ASSERT_EQ(tensor.shape().DebugString(), ShapeDebugString(tf_tensor)); +} + +TEST(ShapeDebugString, RegularShape) { TestShapeMatch(TensorShape({5, 4, 7})); } + +TEST(ShapeDebugString, ScalarShape) { TestShapeMatch(TensorShape({})); } + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index 423302741de..3c8ac934428 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -73,6 +73,12 @@ static void* MyCreateFunc(TF_OpKernelConstruction* ctx) { EXPECT_EQ(TF_FLOAT, type); TF_DeleteStatus(status); + // Exercise kernel NodeDef name read + TF_StringView name_string_view = TF_OpKernelConstruction_GetName(ctx); + std::string node_name = "SomeNodeName"; + std::string candidate_node_name = + std::string(name_string_view.data, name_string_view.len); + EXPECT_EQ(node_name, candidate_node_name); return s; } @@ -96,9 +102,11 @@ namespace tensorflow { static std::unique_ptr GetFakeKernel(const char* device_name, const char* op_name, + const char* node_name, Status* status) { NodeDef def; def.set_op(op_name); + def.set_name(node_name); def.set_device(device_name); def.add_input("input1"); def.add_input("input2"); @@ -114,7 +122,7 @@ static std::unique_ptr GetFakeKernel(const char* device_name, // Tests registration of a single C kernel and checks that calls through the // C/C++ boundary are being made. 
TEST(TestKernel, TestRegisterKernelBuilder) { - const char* kernel_name = "SomeKernelName"; + const char* node_name = "SomeNodeName"; const char* op_name = "FooOp"; const char* device_name = "FakeDeviceName1"; @@ -129,7 +137,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) { { TF_Status* status = TF_NewStatus(); - TF_RegisterKernelBuilder(kernel_name, builder, status); + TF_RegisterKernelBuilder(node_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); @@ -144,7 +152,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) { { Status status; std::unique_ptr kernel = - GetFakeKernel(device_name, op_name, &status); + GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); ASSERT_NE(nullptr, kernel.get()); kernel->Compute(nullptr); @@ -162,7 +170,7 @@ class DummyDevice : public DeviceBase { }; TEST(TestKernel, TestInputAndOutputCount) { - const char* kernel_name = "InputOutputCounterKernel"; + const char* node_name = "InputOutputCounterKernel"; const char* op_name = "BarOp"; const char* device_name = "FakeDeviceName2"; @@ -212,7 +220,7 @@ TEST(TestKernel, TestInputAndOutputCount) { { TF_Status* status = TF_NewStatus(); - TF_RegisterKernelBuilder(kernel_name, builder, status); + TF_RegisterKernelBuilder(node_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_DeleteStatus(status); } @@ -233,7 +241,7 @@ TEST(TestKernel, TestInputAndOutputCount) { Status status; std::unique_ptr kernel = - GetFakeKernel(device_name, op_name, &status); + GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); ASSERT_NE(nullptr, kernel.get()); @@ -252,7 +260,7 @@ TEST(TestKernel, DeleteKernelBuilderIsOkOnNull) { } TEST(TestKernel, TestTypeConstraint) { - const char* kernel_name = "SomeKernelName"; + const char* node_name = "SomeNodeName"; const char* op_name = "TypeOp"; const char* device_name = "FakeDeviceName1"; @@ -267,7 +275,7 @@ TEST(TestKernel, TestTypeConstraint) { TF_Status* status = TF_NewStatus(); TF_KernelBuilder_TypeConstraint(builder, "T", TF_DataType::TF_INT32, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); - TF_RegisterKernelBuilder(kernel_name, builder, status); + TF_RegisterKernelBuilder(node_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); @@ -296,7 +304,7 @@ TEST(TestKernel, TestTypeConstraint) { } TEST(TestKernel, TestHostMemory) { - const char* kernel_name = "SomeKernelName"; + const char* node_name = "SomeNodeName"; const char* op_name = "HostMemoryOp"; const char* device_name = "FakeDeviceName1"; @@ -311,7 +319,7 @@ TEST(TestKernel, TestHostMemory) { TF_KernelBuilder_HostMemory(builder, "input2"); TF_KernelBuilder_HostMemory(builder, "output1"); TF_Status* status = TF_NewStatus(); - TF_RegisterKernelBuilder(kernel_name, builder, status); + TF_RegisterKernelBuilder(node_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_Buffer* buf = TF_GetRegisteredKernelsForOp(op_name, status); @@ -335,12 +343,12 @@ TEST(TestKernel, TestHostMemory) { class DeviceKernelOpTest : public OpsTestBase { protected: - void SetupOp(const char* op_name, const char* kernel_name, + void SetupOp(const char* op_name, const char* node_name, void (*compute_func)(void*, TF_OpKernelContext*)) { TF_KernelBuilder* builder = TF_NewKernelBuilder( op_name, device_name_, nullptr, compute_func, nullptr); TF_Status* status = TF_NewStatus(); - TF_RegisterKernelBuilder(kernel_name, 
builder, status); + TF_RegisterKernelBuilder(node_name, builder, status); EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_DeleteStatus(status); diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc new file mode 100644 index 00000000000..bf6bf069fff --- /dev/null +++ b/tensorflow/c/logging.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/logging.h" + +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stringprintf.h" + +static ::tensorflow::string BuildMessage(const char* fmt, va_list args) { + ::tensorflow::string message; + ::tensorflow::strings::Appendv(&message, fmt, args); + return message; +} + +void TF_Log(TF_LogLevel level, const char* fmt, ...) { + if (level < TF_INFO || level > TF_FATAL) return; + va_list args; + va_start(args, fmt); + auto message = BuildMessage(fmt, args); + switch (level) { + case TF_INFO: + LOG(INFO) << message; + break; + case TF_WARNING: + LOG(WARNING) << message; + break; + case TF_ERROR: + LOG(ERROR) << message; + break; + case TF_FATAL: + LOG(FATAL) << message; + break; + } +} + +void TF_VLog(int level, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + auto message = BuildMessage(fmt, args); + VLOG(level) << message; +} + +void TF_DVLog(int level, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + auto message = BuildMessage(fmt, args); + DVLOG(level) << message; +} diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/logging.h similarity index 52% rename from tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc rename to tensorflow/c/logging.h index c8f00c1f7c0..9583777b661 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc +++ b/tensorflow/c/logging.h @@ -12,25 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_C_LOGGING_H_ +#define TENSORFLOW_C_LOGGING_H_ -#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h" +#include "tensorflow/c/c_api_macros.h" -#include - -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" -#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" +// -------------------------------------------------------------------------- +// C API for tensorflow::Logging. 
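The header that follows exposes printf-style logging to C clients; formatting happens in BuildMessage above and the result is forwarded to the usual LOG/VLOG macros. A minimal usage sketch, where the message text and the `kernel_name`/`elapsed_ms` variables are illustrative:

/* Sketch: logging through the C API declared below. */
TF_Log(TF_WARNING, "kernel %s took %d ms", kernel_name, elapsed_ms);
TF_VLog(2, "emitted only when the process's VLOG level is >= 2");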
+#ifdef __cplusplus extern "C" { +#endif -size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) { - return tensorflow::unwrap(list)->size(); +typedef enum TF_LogLevel { + TF_INFO = 0, + TF_WARNING = 1, + TF_ERROR = 2, + TF_FATAL = 3, +} TF_LogLevel; + +TF_CAPI_EXPORT extern void TF_Log(TF_LogLevel level, const char* fmt, ...); +TF_CAPI_EXPORT extern void TF_VLog(int level, const char* fmt, ...); +TF_CAPI_EXPORT extern void TF_DVLog(int level, const char* fmt, ...); + +#ifdef __cplusplus } +#endif -TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list, - int i) { - return tensorflow::wrap((*tensorflow::unwrap(list))[i]); -} - - -} // end extern "C" +#endif // TENSORFLOW_C_LOGGING_H_ diff --git a/tensorflow/c/ops.cc b/tensorflow/c/ops.cc index 118385ed72c..cc0eddfcbf6 100644 --- a/tensorflow/c/ops.cc +++ b/tensorflow/c/ops.cc @@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() { return reinterpret_cast(new ShapeHandle); } +TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) { + auto* handle = new ShapeHandle; + *handle = reinterpret_cast(ctx)->Scalar(); + return reinterpret_cast(handle); +} + TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( TF_ShapeInferenceContext* ctx, size_t size) { auto* handle = new ShapeHandle; diff --git a/tensorflow/c/ops.h b/tensorflow/c/ops.h index 14868e40260..7463809e35b 100644 --- a/tensorflow/c/ops.h +++ b/tensorflow/c/ops.h @@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx, int i, TF_ShapeHandle* handle, TF_Status* status); +// Returns a newly-allocated scalar shape handle. The returned handle should +// be freed with TF_DeleteShapeHandle. +TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar( + TF_ShapeInferenceContext* ctx); + // Returns a newly-allocate shape handle representing a vector of the given // size. The returned handle should be freed with TF_DeleteShapeHandle. 
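Stepping out of the header text for a moment: the TF_ShapeInferenceContextScalar helper added above pairs with the existing TF_ShapeInferenceContextSetOutput when writing a shape function for an op whose output is a scalar. An illustrative sketch (an editor's addition; MyScalarShapeFn is a placeholder name and status checking is elided):

    static void MyScalarShapeFn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
      // Newly allocated scalar ("[]") shape handle, assigned to output 0.
      TF_ShapeHandle* scalar = TF_ShapeInferenceContextScalar(ctx);
      TF_ShapeInferenceContextSetOutput(ctx, 0, scalar, status);
      // The caller owns the handle and must free it.
      TF_DeleteShapeHandle(scalar);
    }

The vector-sized variant documented immediately above is declared next.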
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize( diff --git a/tensorflow/c/ops_test.cc b/tensorflow/c/ops_test.cc index 482413f966c..9fbf4dcbf8b 100644 --- a/tensorflow/c/ops_test.cc +++ b/tensorflow/c/ops_test.cc @@ -316,5 +316,16 @@ TEST(OpsTest, ShapeInferenceSubshape) { TF_DeleteShapeHandle(handle); } +TEST(OpsTest, ShapeInferenceScalarShape) { + NodeDef def; + shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {}, + {}); + TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c)); + shape_inference::ShapeHandle* scalar_shape = + reinterpret_cast(TF_scalar_shape); + ASSERT_EQ("[]", c.DebugString(*scalar_shape)); + TF_DeleteShapeHandle(TF_scalar_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 0feb986ce44..39d2683226f 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -288,7 +288,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) { if (!tensor.CopyFrom(src, src.shape())) { return nullptr; } - return new TF_Tensor{new tensorflow::TensorInterface(tensor)}; + return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))}; } Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e1fad8e697a..8602bfafff8 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -558,6 +558,7 @@ tf_gen_op_wrappers_cc( "io_ops", "linalg_ops", "list_ops", + "map_ops", "logging_ops", "lookup_ops", "manip_ops", diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index 88cd3fe79d6..3195a357186 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -425,7 +425,7 @@ Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, // Backprop along the in edges to the while loop (i.e. the inputs to the enter // nodes) DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size()); - for (int i = 0; i < dx.size(); ++i) { + for (int i = 0, end = dx.size(); i < end; ++i) { Node* enter_node = while_ctx->enter_nodes()[i]; for (const Edge* e : enter_node->in_edges()) { if (e->IsControlEdge()) continue; @@ -489,7 +489,7 @@ Status SymbolicGradientBuilder::AddGradients() { // All loop-specific control flow ops should have been handled above DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString(); - const size_t num_no_grad = no_grad_dy_indices.size(); + const int num_no_grad = no_grad_dy_indices.size(); if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) { // No grad defined for this op, or all outputs returned 'NoGradient': // Backprop 'NoGradient' along the in edges. @@ -524,7 +524,7 @@ Status SymbolicGradientBuilder::AddGradients() { // make this association explicit. 
for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; - int dx_index = e->dst_input(); + size_t dx_index = e->dst_input(); if (dx_index >= dx.size()) { return errors::Internal( "Invalid gradient output index: ", dx_index, " size: ", dx.size()); diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc index 81870a0efa3..e241cfaebe9 100644 --- a/tensorflow/cc/framework/while_gradients.cc +++ b/tensorflow/cc/framework/while_gradients.cc @@ -34,7 +34,7 @@ Output ToOutput(OutputTensor output_tensor) { std::vector ToOutputVector( const std::vector& output_tensors) { - size_t n = output_tensors.size(); + const int n = output_tensors.size(); std::vector result; result.reserve(n); for (int i = 0; i < n; ++i) result.push_back(ToOutput(output_tensors[i])); diff --git a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc index ad80b74f1d5..cf5f742538e 100644 --- a/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc +++ b/tensorflow/cc/saved_model/experimental/tests/saved_model_api_test.cc @@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) { std::unique_ptr model = SavedModelAPI::Load(model_dir, *runtime, &status); - // TODO(bmzhao): Change this to expect TF_OK when loading is implemented. - // That unblocks writing other tests that require a TF_SavedModel*, - // like loading a ConcreteFunction. This test at least checks that the - // C API builds and can be minimally run. - EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message(); + EXPECT_EQ(status.code(), TF_OK) << status.message(); } INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests, diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index d091146c75a..ff255dd9cc1 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -308,6 +308,8 @@ cc_library( ], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", ], alwayslink = 1, diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index e4df3090046..ae50a447b19 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -172,7 +172,7 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { - size_t num_args = ps.parameters_size(); + const int num_args = ps.parameters_size(); // feed_size() + variable_size() is the maximum number of args as an // implementation may not create an argument for an unused variable. 
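A note on the pattern recurring in this and many of the surrounding hunks (gradients.cc, while_gradients.cc, codegen.cc above, and the compiler/jit files below): container sizes are hoisted once into signed locals such as `const int num_args` or a loop-scoped `end`, so that comparisons against signed indices no longer mix signedness. A condensed before/after sketch (an editor's illustration):

    std::vector<std::string> parts = {"a", "b", "c"};

    // Before: `i` is signed while parts.size() is unsigned, so the comparison
    // triggers -Wsign-compare, and size() is evaluated on every iteration.
    for (int i = 0; i < parts.size(); ++i) { /* ... */ }

    // After: evaluate the size once into a signed bound, as done throughout
    // this change.
    for (int i = 0, end = parts.size(); i < end; ++i) { /* ... */ }

The bounds check that continues below compares against the signed num_args computed in the same way.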
if (config.feed_size() + config.variable_size() < num_args) { @@ -229,8 +229,9 @@ Status GenResultMethods(const tf2xla::Config& config, int readonly_variables = absl::c_count_if( config.variable(), [](const tf2xla::Variable& var) { return var.readonly(); }); - if (config.fetch_size() + config.variable_size() - readonly_variables != - num_results) { + const int actual_num_results = + config.fetch_size() + config.variable_size() - readonly_variables; + if (actual_num_results != num_results) { return errors::InvalidArgument("mismatch between fetch_size(", config.fetch_size(), ")+variable_size(", config.variable_size(), ") and tuple_size(", @@ -273,7 +274,7 @@ Status GenResultMethods(const tf2xla::Config& config, // Generate methods for variables. Status GenVariableMethods(const tf2xla::Config& config, const xla::ProgramShapeProto& ps, string* methods) { - size_t num_args = ps.parameters_size(); + const int num_args = ps.parameters_size(); for (int i = config.feed_size(); i < num_args; ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( @@ -401,7 +402,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); std::vector buffer_infos_as_strings = BufferInfosToCppExpression(buffer_infos); - if (result_index < 0 || result_index >= buffer_infos.size()) { + const int64 buffer_infos_size = buffer_infos.size(); + if (result_index < 0 || result_index >= buffer_infos_size) { return errors::InvalidArgument("result index: ", result_index, " is outside the range of temp sizes: [0,", buffer_infos.size(), ")"); @@ -797,8 +799,8 @@ Status ParseCppClass(const string& cpp_class, string* class_name, // Allow a fully qualified name that starts with "::". parts.erase(parts.begin()); } - for (int i = 0; i < parts.size(); ++i) { - if (i < parts.size() - 1) { + for (int i = 0, end = parts.size(); i < end; ++i) { + if (i < end - 1) { TF_RETURN_IF_ERROR(ValidateCppIdent( parts[i], "in namespace component of cpp_class: " + cpp_class)); namespaces->push_back(parts[i]); diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 0c44ed8bf37..5f6b3dc7101 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -63,7 +63,7 @@ py_binary( testonly = 1, srcs = ["make_test_graphs.py"], python_version = "PY3", - srcs_version = "PY2AND3", + srcs_version = "PY3", deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python", # TODO(b/34059704): remove when fixed @@ -110,8 +110,8 @@ genrule( # have control of the full GPU. 
cmd = "CUDA_VISIBLE_DEVICES='' " + "$(location :make_test_graphs) --out_dir $(@D)", + exec_tools = [":make_test_graphs"], tags = ["manual"], - tools = [":make_test_graphs"], ) tf_library( diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5ec0575ed77..d05bb8264c3 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -95,6 +95,7 @@ cc_library( ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", @@ -115,6 +116,7 @@ cc_library( ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", @@ -126,22 +128,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "xla_interpreter_device", - srcs = ["xla_interpreter_device.cc"], - visibility = [":friends"], - deps = [ - ":jit_compilation_passes", - ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_ops", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep - "@com_google_absl//absl/memory", - ], - alwayslink = 1, -) - cc_library( name = "xla_tensor", srcs = ["xla_tensor.cc"], @@ -172,6 +158,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", @@ -208,6 +195,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/core/kernels/data:optional_ops", "//tensorflow/core/kernels/data:prefetch_dataset_op", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/stream_executor:tf_allocator_adapter", "//tensorflow/stream_executor/platform", ] @@ -218,14 +206,18 @@ cc_library( "xla_device.cc", "xla_device_context.cc", "xla_device_ops.cc", + "xla_platform_info.cc", ], hdrs = [ "xla_compile_on_demand_op.h", "xla_device.h", "xla_device_context.h", "xla_device_ops.h", + "xla_platform_info.h", ], - deps = XLA_DEVICE_DEPS, + # Public visibility is needed for external TF/XLA backends. + visibility = ["//visibility:public"], + deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"], ) cc_library( @@ -341,6 +333,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", @@ -388,22 +381,26 @@ cc_library( alwayslink = 1, ) -# Linked by tensorflow core, without registration of jit compilation passes -# which is not necessary to create and run a XlaLocalLaunchBase kernel. -# Linking jit compilation passes could cause programs stuck right now (b/140069592). 
cc_library( - name = "xla_kernel_creator_util", + name = "xla_kernel_creator", srcs = [ - "xla_kernel_creator_util.cc", + "xla_kernel_creator.cc", + "xla_kernel_creator.h", + ], + visibility = [ + ":internal", + "//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__", + "//tensorflow/core/common_runtime/eager:__pkg__", ], - hdrs = ["xla_kernel_creator_util.h"], - visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"], deps = [ ":common", ":compilability_check_util", ":compilation_passes", + ":flags", + ":jit_compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -415,25 +412,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "xla_kernel_creator", - srcs = [ - "xla_kernel_creator.cc", - "xla_kernel_creator.h", - ], - deps = [ - ":compilability_check_util", - ":flags", - ":jit_compilation_passes", - ":xla_kernel_creator_util", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], - alwayslink = 1, -) - tf_cc_test( name = "xla_kernel_creator_test", srcs = [ @@ -639,6 +617,7 @@ cc_library( "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", @@ -648,11 +627,11 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_bounds_check", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:bounds_check", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -677,11 +656,11 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", - "//tensorflow/core:framework_bounds_check", "//tensorflow/core:framework_internal", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/framework:bounds_check", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -698,6 +677,7 @@ cc_library( hdrs = ["device_util.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", @@ -912,6 +892,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 5a57008cf61..a340b9d3f45 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -452,7 +452,7 @@ Status PredicateInt32Inputs(const Scope& root, Node* n, 
root.graph()->AddControlEdge(predicate_as_control.node(), identity_n.operation.node()); - for (int i = 0; i < int32_inputs.size(); i++) { + for (int i = 0, end = int32_inputs.size(); i < end; i++) { TF_RETURN_IF_ERROR(root.graph()->UpdateEdge(identity_n[i].node(), i, n, int32_inputs_input_idxs[i])); } diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index a21cb6b98dd..3b20784cc29 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -257,7 +257,7 @@ class RecursiveCompilabilityChecker { UncompilableNodesMap* uncompilable_nodes_map); // Make sure we don't recurse infinitely on recursive functions. - const int kMaxRecursionDepth = 10; + const size_t kMaxRecursionDepth = 10; const OperationFilter& op_filter_; const DeviceType& jit_device_type_; diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc index 375d30c4cf3..d8749baf872 100644 --- a/tensorflow/compiler/jit/device_util.cc +++ b/tensorflow/compiler/jit/device_util.cc @@ -26,8 +26,8 @@ using xla::StatusOr; void DeviceSet::Insert(DeviceId device_id) { int word_index = device_id.id() / kWordSize; int bit_index = device_id.id() % kWordSize; - - if (word_index >= storage_.size()) { + const int storage_size = storage_.size(); + if (word_index >= storage_size) { storage_.resize(word_index + 1, 0); } @@ -39,7 +39,7 @@ void DeviceSet::UnionWith(const DeviceSet& other) { storage_.resize(other.storage_.size(), 0); } - for (int i = 0; i < other.storage_.size(); i++) { + for (int i = 0, end = other.storage_.size(); i < end; i++) { storage_[i] |= other.storage_[i]; } } diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index 35f3321b47b..6304cc813ca 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -72,7 +72,8 @@ class DeviceSet { void ForEach(FnTy func) const { // This is really a poor man's iterator, we should consider writing a proper // iterator if this ends up being used widely. 
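For readers unfamiliar with the idiom in the loop that follows: for an unsigned word, `word & -word` isolates its lowest set bit, so the loop can visit only the bits that are actually set (each corresponding to a device id) instead of scanning all 64 positions. A standalone sketch of the same technique (an editor's illustration; ForEachSetBit is a made-up helper and __builtin_ctzll is a GCC/Clang builtin):

    #include <cstdint>
    #include <functional>

    void ForEachSetBit(uint64_t word, const std::function<void(int)>& fn) {
      while (word != 0) {
        uint64_t only_lowest_bit_set = word & -word;  // isolate the lowest 1 bit
        fn(__builtin_ctzll(word));                    // index of that bit
        word ^= only_lowest_bit_set;                  // clear it and continue
      }
    }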
- for (int word_index = 0; word_index < storage_.size(); word_index++) { + for (int word_index = 0, end = storage_.size(); word_index < end; + word_index++) { uint64 word = storage_[word_index]; while (word != 0) { uint64 only_lowest_bit_set = word & -word; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 435c2ec5f7f..d482642b44c 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1132,7 +1132,8 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { if (n->type_string() == kArgOp) { int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - if (index < 0 || index >= types->size()) { + const int num_types = types->size(); + if (index < 0 || index >= num_types) { return errors::InvalidArgument("Invalid argument number"); } (*types)[index] = n->output_type(0); @@ -1149,7 +1150,8 @@ static Status RenumberArguments(Graph* graph, if (n->type_string() == kArgOp) { int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - if (index < 0 || index >= permutation.size()) { + const int permutation_size = permutation.size(); + if (index < 0 || index >= permutation_size) { return errors::InvalidArgument("Invalid argument number"); } n->AddAttr("index", permutation[index]); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 6640a5d5dba..efd2ef24c3b 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -295,19 +295,6 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) -// These dummy Op registrations are here because the real Op registrations live -// in contrib and there can't be a dependence from this test to contrib. 
-REGISTER_OP("XlaHostCompute") - .Input("inputs: Tinputs") - .Output("outputs: Toutputs") - .Attr("Tinputs: list(type) >= 0") - .Attr("Toutputs: list(type) >= 0") - .Attr("ancestors: list(string) >= 0") - .Attr("key: string") - .Attr("shape_inference_graph: func") - .Attr("shapes: list(shape) >= 0") - .SetShapeFn(::tensorflow::shape_inference::UnknownShape); - REGISTER_OP("InputTest") .Output("o: float") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { @@ -947,6 +934,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -1114,6 +1103,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", @@ -1130,6 +1121,8 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -1266,6 +1259,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, @@ -1295,6 +1290,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, @@ -1428,6 +1425,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, @@ -1454,6 +1453,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutsideDependencyFromOutside) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F2_F2_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, @@ -1566,6 +1567,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, {"_outside_compilation_subgraph", "O1"}, @@ -1658,6 +1661,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({shape_proto_expected})}, 
{"_outside_compilation_subgraph", "O1"}, @@ -1765,6 +1770,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -1875,6 +1882,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -2009,6 +2018,8 @@ TEST(EncapsulateSubgraphsTest, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph1}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -2023,6 +2034,8 @@ TEST(EncapsulateSubgraphsTest, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"shape_inference_graph", shape_inference_graph2}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", @@ -2153,6 +2166,8 @@ TEST(EncapsulateSubgraphsTest, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", @@ -2169,6 +2184,8 @@ TEST(EncapsulateSubgraphsTest, {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -2296,6 +2313,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -2310,6 +2329,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O2"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O2"}, {"_xla_token_input_nodes", @@ -2325,6 +2346,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationClusterDependency) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O3"}, {"shape_inference_graph", NameAttrList()}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O3"}, {"_xla_token_input_nodes", @@ -2451,6 +2474,8 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", @@ -2567,6 +2592,8 @@ TEST(EncapsulateSubgraphsTest, 
OutsideCompilationShapeInference) { {"ancestors", absl::Span({})}, {"key", "host_compute_channel_F1_F1_O1"}, {"shape_inference_graph", shape_inference_graph}, + {"tpu_core", 0}, + {"cost_estimate_ns", 1000000}, {"shapes", absl::Span({})}, {"_outside_compilation_subgraph", "O1"}, {"_xla_token_input_nodes", diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 5325f6faa31..12afee70716 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -139,7 +139,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( // Remove the edge from host to outside compilation. Add a placeholder as // outside compilation node input. std::map, Node*> placeholders; - for (int i = 0; i < edges.size(); i++) { + for (int i = 0, end = edges.size(); i < end; i++) { Node* dst = g->FindNodeId(edges[i].dst_node_id); const Edge* e; TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e)); @@ -185,7 +185,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( // Other edge in `edges` might have `e->dst()` as src or dst // node. Before removing `e->dst()`, replace those edges with // corresponding edges for `dst_replace_node`. - for (int j = i + 1; j < edges.size(); j++) { + for (int j = i + 1, end = edges.size(); j < end; j++) { if (edges[j].dst_node_id == edges[i].dst_node_id) { edges[j].dst_node_id = dst_replace_node->id(); } @@ -238,7 +238,7 @@ Status PostprocessDataEdgesBetweenOutsideCompilations( g->AddControlEdge(original_node, e->dst()); g->RemoveEdge(e); } - for (int i = 0; i < data_edges.size(); i++) { + for (int i = 0, end = data_edges.size(); i < end; i++) { Node* dst = data_edges[i].dst; NodeDef new_def = dst->def(); int dst_input = data_edges[i].dst_input; @@ -253,7 +253,7 @@ Status PostprocessDataEdgesBetweenOutsideCompilations( // Other edges might have `dst` as dst node. Update those edges with // `replace_node`. 
- for (int j = i + 1; j < data_edges.size(); j++) { + for (int j = i + 1, end = data_edges.size(); j < end; j++) { if (data_edges[j].dst == dst) { data_edges[j].dst = replace_node; } diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 2b7a6c83b8b..ed25baa62ff 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -351,14 +351,14 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, if (!status.ok()) { return status; } - for (int i = 0; i < data_inputs.size(); ++i) { + for (int i = 0, end = data_inputs.size(); i < end; ++i) { graph->AddEdge(data_inputs[i].first, data_inputs[i].second, xla_launch, i); } for (Node* n : control_inputs) { graph->AddControlEdge(n, xla_launch); } - for (int i = 0; i < data_outputs.size(); ++i) { + for (int i = 0, end = data_outputs.size(); i < end; ++i) { for (const auto& successor : data_outputs[i]) { graph->AddEdge(xla_launch, i, successor.first, successor.second); } diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index 5f1c3d536a8..fef43eb8730 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -95,7 +95,7 @@ Status GetArgDataTypes(const std::vector& arg_nodes, TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); (*recv_at_host_dtypes)[index] = dtype; } - for (int i = 0; i < recv_at_host_dtypes->size(); i++) { + for (int i = 0, end = recv_at_host_dtypes->size(); i < end; i++) { if ((*recv_at_host_dtypes)[i] == DT_INVALID) { return errors::Internal("Cannot get datatype for input ", i); } @@ -160,7 +160,7 @@ xla::StatusOr ReplaceArgNodesWithRecvAtHostNode( } // Rewrite dst nodes because their input changed. - for (int i = 0; i < out_edge_info.size(); i++) { + for (int i = 0, end = out_edge_info.size(); i < end; i++) { const OutEdgeInfo edge = out_edge_info[i]; if (edge.dst_input == Graph::kControlSlot) { continue; @@ -174,7 +174,7 @@ xla::StatusOr ReplaceArgNodesWithRecvAtHostNode( // Other edges might have `dst` as dst node as well. Update those edges // with `dst_replace`. 
- for (int j = i + 1; j < out_edge_info.size(); j++) { + for (int j = i + 1, end = out_edge_info.size(); j < end; j++) { if (out_edge_info[j].dst == dst) { out_edge_info[j].dst = dst_replace; } @@ -196,7 +196,7 @@ Status GetRetDataTypes(const std::vector& ret_nodes, TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "T", &dtype)); (*send_from_host_dtypes)[index] = dtype; } - for (int i = 0; i < send_from_host_dtypes->size(); i++) { + for (int i = 0, end = send_from_host_dtypes->size(); i < end; i++) { if ((*send_from_host_dtypes)[i] == DT_INVALID) { return errors::Internal("Cannot get datatype for output ", i); } @@ -226,7 +226,8 @@ xla::StatusOr BuildSendFromHostNode( for (auto* n : ret_nodes) { int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - if (index < 0 || index >= send_from_host_dtypes.size()) { + const int num_dtypes = send_from_host_dtypes.size(); + if (index < 0 || index >= num_dtypes) { return errors::Internal("Invalid _Retval index: ", index); } for (auto edge : n->in_edges()) { @@ -361,7 +362,8 @@ xla::StatusOr BuildXlaHostComputeNodeDef( continue; } - if (e->dst_input() < 0 || e->dst_input() >= input_dtypes.size()) { + const int input_dtypes_size = input_dtypes.size(); + if (e->dst_input() < 0 || e->dst_input() >= input_dtypes_size) { return errors::Internal("Invalid dst_input: ", e->dst_input()); } inputs[e->dst_input()] = NodeDefBuilder::NodeOut{ @@ -500,7 +502,7 @@ void AddEdgesFromOutsideCompilationNodes( const std::vector& data_types, const std::vector& outside_compilation_nodes, Graph* g, Node* n) { // Add edges from outside compilation nodes to While node. - for (int i = original_arg_count; i < data_types.size(); i++) { + for (int i = original_arg_count, end = data_types.size(); i < end; i++) { Node* outside_compilation_node = outside_compilation_nodes[i - original_arg_count]; g->AddEdge(outside_compilation_node, 0, n, i + arg_to_input_edge_offset); @@ -619,7 +621,7 @@ Status PostprocessLiftedArgsForWhile( lifted_arg_nodes_and_outside_compilation_nodes.end(), std::back_inserter(lifted_arg_nodes), [](const std::pair& pair) { return pair.first; }); - for (int i = original_arg_count; i < data_types.size(); i++) { + for (int i = original_arg_count, end = data_types.size(); i < end; i++) { TF_ASSIGN_OR_RETURN(Node * arg_node, AddOutsideCompilationInputArgToFunctionBody( *body_function_body, i, data_types[i])); @@ -648,7 +650,7 @@ Status PostprocessLiftedArgsForWhile( AttrSlice(&cond_func.attr()), fld, &cond_function_body)); - for (int i = original_arg_count; i < data_types.size(); i++) { + for (int i = original_arg_count, end = data_types.size(); i < end; i++) { xla::StatusOr arg_node_or = AddOutsideCompilationInputArgToFunctionBody(*cond_function_body, i, data_types[i]); @@ -759,7 +761,7 @@ Status PostprocessLiftedArgsForIf( data_types, outside_compilation_nodes, g, n); - for (int i = original_arg_count; i < data_types.size(); ++i) { + for (int i = original_arg_count, end = data_types.size(); i < end; ++i) { TF_ASSIGN_OR_RETURN(Node * then_branch_arg_node, AddOutsideCompilationInputArgToFunctionBody( *then_branch_function_body, i, data_types[i])); @@ -837,7 +839,7 @@ Status PostprocessLiftedArgsForCall( lifted_arg_nodes_and_outside_compilation_nodes.end(), std::back_inserter(lifted_arg_nodes), [](const std::pair& pair) { return pair.first; }); - for (int i = original_arg_count; i < data_types.size(); ++i) { + for (int i = original_arg_count, end = data_types.size(); i < end; ++i) { TF_ASSIGN_OR_RETURN( Node * arg_node, 
AddOutsideCompilationInputArgToFunctionBody(*fbody, i, data_types[i])); @@ -855,7 +857,7 @@ Status PostprocessLiftedArgsForCall( // We need to recreate the node. Otherwise TF will not know n->num_inputs() // has increased. NodeDef node_def = n->def(); - for (int i = original_arg_count; i < data_types.size(); i++) { + for (int i = original_arg_count, end = data_types.size(); i < end; i++) { Node* outside_compilation_node = lifted_arg_nodes_and_outside_compilation_nodes[i - original_arg_count] .second; @@ -1804,7 +1806,9 @@ TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( continue; } - TF_RET_CHECK(e->dst_input() >= 0 && e->dst_input() < inputs.size()); + const bool input_size_check = + e->dst_input() < static_cast(inputs.size()); + TF_RET_CHECK(e->dst_input() >= 0 && input_size_check); inputs[e->dst_input()] = NodeDefBuilder::NodeOut{e->src()->name(), e->src_output(), e->src()->output_type(e->src_output())}; @@ -2420,6 +2424,7 @@ Status ExtractOutsideCompilationForFunction( auto updated_fdef = absl::make_unique(); TF_RETURN_IF_ERROR( GraphToFunctionDef(*g, new_func_name, updated_fdef.get())); + updated_fdef->mutable_signature()->set_is_stateful(true); const FunctionDef* original_fdef = fld->Find(func_name); if (original_fdef) { for (const auto& attr : original_fdef->attr()) { diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index a6f2bd41275..b727dfc72fc 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -422,19 +422,6 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, NoHostGraph) { EXPECT_EQ(fld.Find("host_graph"), nullptr); } -REGISTER_OP("XlaSendToHost") - .Input("input: Tinput") - .Attr("Tinput: type") - .Attr("key: string") - .SetIsStateful(); - -REGISTER_OP("XlaRecvFromHost") - .Output("output: Toutput") - .Attr("Toutput: type") - .Attr("shape: shape") - .Attr("key: string") - .SetIsStateful(); - TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInIf) { // Build the XLA computation func. // "const0" (bool) diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index d1301a8c40f..ff085c854c6 100644 --- a/tensorflow/compiler/jit/flags.cc +++ b/tensorflow/compiler/jit/flags.cc @@ -268,10 +268,4 @@ void AppendMarkForCompilationPassFlags(std::vector* flag_list) { AppendMarkForCompilationPassFlagsInternal(flag_list); } -static bool xla_is_enabled = false; - -void SetXlaIsEnabled() { xla_is_enabled = true; } - -bool IsXlaEnabled() { return xla_is_enabled; } - } // namespace tensorflow diff --git a/tensorflow/compiler/jit/flags.h b/tensorflow/compiler/jit/flags.h index 89e20d9f8ea..6c54fc8825e 100644 --- a/tensorflow/compiler/jit/flags.h +++ b/tensorflow/compiler/jit/flags.h @@ -162,14 +162,6 @@ MlirCommonFlags* GetMlirCommonFlags(); void AppendMarkForCompilationPassFlags( std::vector* flag_list); -// Makes all future calls to `IsXlaEnabled()` return `true`. -// -// Should only be called when XLA is linked in. -void SetXlaIsEnabled(); - -// Returns whether XLA is enabled. 
-bool IsXlaEnabled(); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_ diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 6c5e3a745e2..416e101a025 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -461,7 +461,7 @@ string GraphCycles::DebugString() const { } string result = "digraph {\n"; - for (int i = 0; i < rep_->nodes_.size(); i++) { + for (int i = 0, end = rep_->nodes_.size(); i < end; i++) { if (free_nodes_set.contains(i)) { continue; } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 23931a0d7cd..bf9d88b73fa 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -194,7 +194,7 @@ Status ComputeSliceSize(const Scope& host_scope, ConstantCache constant_pool(host_scope, control_deps); std::vector slice_size; - for (int i = 0; i < slice_inputs.size_as_vector.size(); i++) { + for (int i = 0, end = slice_inputs.size_as_vector.size(); i < end; i++) { if (slice_inputs.size_as_vector[i] >= 0) { slice_size.push_back( constant_pool.Get1DHostConstant(slice_inputs.size_as_vector[i])); diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 347bae087df..eb9ad8a2e85 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -21,6 +21,7 @@ XLA_OPS_DEPS = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 48347a2915f..9cee4b9af28 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_activity_listener.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -63,38 +64,6 @@ namespace tensorflow { namespace { -XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) { - DeviceType device_type = ctx->device_type(); - se::Platform::Id platform_id = nullptr; - const XlaDevice::Metadata* xla_device_metadata = nullptr; - se::DeviceMemoryAllocator* custom_allocator = nullptr; - - if (ctx->device_type() == DeviceType(DEVICE_CPU)) { - platform_id = se::host::kHostPlatformId; - } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { - platform_id = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() - ->platform() - ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. 
- // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which xla_allocator above uses) as on an XlaDevice, this is a dummy - // allocator that returns XlaTensor objects. The XlaCompiler needs a real - // allocator to allocate real buffers. - platform_id = xla_device_metadata->platform()->id(); - custom_allocator = - xla_device_metadata->client()->backend().memory_allocator(); - } - - return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, - custom_allocator); -} // A closure describing how to run a compiled version of a TensorFlow function. // @@ -178,31 +147,6 @@ class XlaExecutableClosureStore { TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); }; -// Return allocator from platform info if non-null, or populate and return a -// pointer to the allocator adapter with allocator from context. -// -// This is necessary because for XLA devices the underlying TF allocator returns -// dummy tensors. -se::DeviceMemoryAllocator* GetAllocator( - absl::optional* tf_allocator_adapter, - OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { - if (platform_info.custom_allocator()) { - return platform_info.custom_allocator(); - } - if (!ctx->op_device_context()) { - // Stream is not set for the host platform. - se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) - .ValueOrDie(); - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); - return &tf_allocator_adapter->value(); - } - // platform_info. - tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), - ctx->op_device_context()->stream()); - return &tf_allocator_adapter->value(); -} - } // namespace XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, @@ -214,70 +158,15 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, constants_(constants), resources_(resources), function_(function), - platform_info_(PlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromContext(ctx)), has_ref_vars_(has_ref_vars) {} -static Status BuildCompilationCache(OpKernelContext* ctx, - const XlaPlatformInfo& platform_info, - XlaCompilationCache** cache) { - if (platform_info.xla_device_metadata()) { - *cache = new XlaCompilationCache( - platform_info.xla_device_metadata()->client(), - platform_info.xla_device_metadata()->jit_device_type()); - return Status::OK(); - } - - auto platform = - se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); - if (!platform.ok()) { - return platform.status(); - } - - xla::StatusOr compiler_for_platform = - xla::Compiler::GetForPlatform(platform.ValueOrDie()); - if (!compiler_for_platform.ok()) { - // In some rare cases (usually in unit tests with very small clusters) we - // may end up transforming an XLA cluster with at least one GPU operation - // (which would normally force the cluster to be compiled using XLA:GPU) - // into an XLA cluster with no GPU operations (i.e. containing only CPU - // operations). Such a cluster can fail compilation (in way that - // MarkForCompilation could not have detected) if the CPU JIT is not linked - // in. - // - // So bail out of _XlaCompile in this case, and let the executor handle the - // situation for us. 
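Orientation for this large hunk: PlatformInfoFromContext, GetAllocator, and the compilation-cache builder deleted here are not dropped; they move into the new xla_platform_info library added earlier in this change (see the jit/BUILD edits and the new xla_platform_info.h include above). Reconstructed from the call sites that remain in this file, the relocated entry points look roughly as follows; this is an editor's sketch, not the verbatim contents of xla_platform_info.h:

    // Gathers device/platform information for the op being constructed.
    XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);

    // Builds an XlaCompilationCache appropriate for the platform.
    Status BuildXlaCompilationCache(OpKernelContext* ctx,
                                    const XlaPlatformInfo& platform_info,
                                    XlaCompilationCache** cache);

    // Returns the platform's custom allocator if one is set, otherwise wraps
    // the device allocator in a TfAllocatorAdapter.
    se::DeviceMemoryAllocator* GetAllocator(
        absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
        OpKernelContext* ctx, const XlaPlatformInfo& platform_info);

    // Assembles XlaCompiler::Options from the cache, context, and platform info.
    XlaCompiler::Options GenerateCompilerOptions(
        XlaCompilationCache* cache, OpKernelContext* ctx,
        const XlaPlatformInfo& platform_info, bool has_ref_vars,
        absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);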
- const Status& status = compiler_for_platform.status(); - if (status.code() == error::NOT_FOUND) { - return errors::Unimplemented("Could not find compiler for platform ", - platform.ValueOrDie()->Name(), ": ", - status.ToString()); - } - } - - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - platform_info.device_type().type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, absl::Span variable_infos, - absl::Span constants, bool lazy, xla::LocalClient** client, + absl::Span constants, bool lazy, bool may_alias_resource_update, + xla::LocalClient** client, const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { // We store information about the JIT-compiled XLA computation @@ -291,7 +180,7 @@ static Status CompileToLocalExecutable( TF_RETURN_IF_ERROR(rm->LookupOrCreate( rm->default_container(), "xla_cache", &cache, [&](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, platform_info, cache); + return BuildXlaCompilationCache(ctx, platform_info, cache); })); // Hold the reference to the JIT during evaluation. (We could probably // free it sooner because the ResourceMgr will retain a reference, but @@ -301,37 +190,22 @@ static Status CompileToLocalExecutable( *client = static_cast(cache->client()); absl::optional tf_allocator_adapter; - XlaCompiler::Options options; - options.client = *client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = - (platform_info.platform_id() == se::host::kHostPlatformId); - options.device_allocator = - GetAllocator(&tf_allocator_adapter, ctx, platform_info); - if (platform_info.xla_device_metadata()) { - options.shape_representation_fn = - platform_info.xla_device_metadata()->shape_representation_fn(); - } - // If reference variables are not present in the graph, we can safely alias - // passthrough parameters without performing a copy. - options.alias_passthrough_params = - !has_ref_vars && !platform_info.is_on_xla_device(); + XlaCompiler::Options options = GenerateCompilerOptions( + cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter); std::map constant_args; for (int i : constants) { constant_args.insert({i, ctx->input(i)}); } + XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; // Optimization: where possible, have the computation return a naked array // rather than a one-element tuple. 
compile_options.always_return_tuple = false; + compile_options.alias_resource_update = !has_ref_vars && + !platform_info.is_on_xla_device() && + may_alias_resource_update; std::vector args; TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( @@ -350,20 +224,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; - ResourceVarsSnapshot variables_snapshot; + std::vector variable_infos; { - std::vector variable_infos; OP_REQUIRES_OK( ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status s = CompileToLocalExecutable( ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, - variable_infos, constants_, /*lazy=*/false, &client, - &compilation_result, &executable); + variable_infos, constants_, /*lazy=*/false, + /*may_alias_resource_update=*/true, &client, &compilation_result, + &executable); OP_REQUIRES_OK(ctx, s); - OP_REQUIRES_OK(ctx, - SnapshotResourceVariables(ctx, resources_, variable_infos, - &variables_snapshot)); + } + + std::map resource_var_ptrs; + for (int i = 0; i < resources_.size(); i++) { + resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor(); } se::Stream* stream = @@ -374,12 +250,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { absl::optional tf_allocator_adapter; se::DeviceMemoryAllocator* allocator = GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + int device_ordinal = stream ? stream->parent()->device_ordinal() + : client->default_device_ordinal(); XlaComputationLaunchContext launch_context( - client, allocator, + client, allocator, device_ordinal, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), platform_info_.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot, - /*missing_ctx_input_prefix=*/0); + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + xla::StatusOr> execution_inputs = + launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs, + /*missing_ctx_input_prefix=*/0, + input_output_alias); + OP_REQUIRES_OK(ctx, execution_inputs.status()); // Execute the computation. 
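To orient the reader before the execution code continues below: the launch path now threads xla::ExecutionInput and xla::ExecutionOutput through the launch context instead of raw argument buffers, which lets input buffers be donated according to the executable's input/output alias config. Condensed, the new flow is roughly (an editor's paraphrase of the surrounding code; error handling and the async variant are elided):

    auto inputs = launch_context.PopulateInputs(ctx, compilation_result,
                                                resource_var_ptrs,
                                                /*missing_ctx_input_prefix=*/0,
                                                input_output_alias);
    auto output = executable->Run(std::move(*inputs), run_options);
    launch_context.PopulateOutputs(ctx, compilation_result,
                                   output->ConsumeResult(),
                                   /*missing_ctx_input_prefix=*/0,
                                   absl::MakeSpan(variable_infos),
                                   input_output_alias, resource_var_ptrs);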
VLOG(2) << "Executing computation."; @@ -403,24 +286,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { Env* env = Env::Default(); auto start_time = env->NowMicros(); - xla::StatusOr run_result; + xla::StatusOr execution_output; if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) { - run_result = executable->Run(launch_context.arguments(), run_options); + execution_output = + executable->Run(std::move(*execution_inputs), run_options); } else { - run_result = executable->RunAsync(launch_context.arguments(), run_options); + execution_output = + executable->RunAsync(std::move(*execution_inputs), run_options); } - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + OP_REQUIRES(ctx, execution_output.ok(), execution_output.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; + OP_REQUIRES_OK( + ctx, launch_context.PopulateOutputs( + ctx, compilation_result, execution_output->ConsumeResult(), + /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos), + input_output_alias, resource_var_ptrs)); - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); - OP_REQUIRES_OK(ctx, - launch_context.PopulateOutputs( - ctx, compilation_result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0, input_output_alias, - variables_snapshot)); VLOG(1) << "Done"; } @@ -490,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) constants_(ConstantsVector(ctx)), resources_(ResourcesVector(ctx)), function_(FunctionAttr(ctx)), - platform_info_(PlatformInfoFromContext(ctx)), + platform_info_(XlaPlatformInfoFromContext(ctx)), must_compile_(MustCompileAttr(ctx)), has_ref_vars_(HasRefVars(ctx)) {} @@ -516,10 +399,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK( ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); + + // Do not alias resource updates as locking variables in XlaCompile and + // unlocking them in XlaRun may lead to deadlocks. Status status = CompileToLocalExecutable( ctx, function_, has_ref_vars_, platform_info_, variable_infos, constants_, - /*lazy=*/!must_compile_, &client, &kernel, &executable); + /*lazy=*/!must_compile_, + /*may_alias_resource_update=*/false, &client, &kernel, &executable); OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, variable_infos, &variables)); if (must_compile_ || status.code() != error::UNIMPLEMENTED) { @@ -574,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { } XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) - : OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {} + : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} void XlaRunOp::Compute(OpKernelContext* ctx) { VLOG(3) << "XlaRunOp " << def().name(); @@ -587,14 +474,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { absl::optional tf_allocator_adapter; se::DeviceMemoryAllocator* allocator = GetAllocator(&tf_allocator_adapter, ctx, platform_info_); + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + int device_ordinal = stream ? 
stream->parent()->device_ordinal() + : closure.client()->default_device_ordinal(); XlaComputationLaunchContext launch_context( - closure.client(), allocator, + closure.client(), allocator, device_ordinal, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); // We're missing the must-be-constant inputs, tell `PopulateInputs` // about this. We don't actually need these inputs because they've // already been baked into the compiled kernel. + const xla::HloInputOutputAliasConfig& input_output_alias = + closure.executable()->executable()->module().input_output_alias_config(); + xla::StatusOr> execution_inputs; + std::map snapshot_ptrs; { tensorflow::profiler::TraceMe hlo_module_activity( [&] { @@ -604,13 +499,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { }, tensorflow::profiler::TraceMeLevel::kInfo); - launch_context.PopulateInputs( - ctx, closure.compilation_result(), closure.resource_var_snapshots(), - /*missing_ctx_input_prefix=*/closure.num_constant_args()); + for (auto& p : closure.resource_var_snapshots()) { + snapshot_ptrs.emplace(p.first, + p.second.has_value() ? &p.second.value() : nullptr); + } + execution_inputs = launch_context.PopulateInputs( + ctx, closure.compilation_result(), snapshot_ptrs, + /*missing_ctx_input_prefix=*/closure.num_constant_args(), + input_output_alias); + OP_REQUIRES_OK(ctx, execution_inputs.status()); } - se::Stream* stream = - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(allocator); @@ -631,21 +530,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { Env* env = Env::Default(); auto start_time = env->NowMicros(); - xla::StatusOr run_result; + xla::StatusOr execution_output; if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) { - run_result = - closure.executable()->Run(launch_context.arguments(), run_options); + execution_output = + closure.executable()->Run(std::move(*execution_inputs), run_options); } else { - run_result = - closure.executable()->RunAsync(launch_context.arguments(), run_options); + execution_output = closure.executable()->RunAsync( + std::move(*execution_inputs), run_options); } - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + OP_REQUIRES(ctx, execution_output.ok(), execution_output.status()); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; - const xla::HloInputOutputAliasConfig& input_output_alias = - closure.executable()->executable()->module().input_output_alias_config(); tensorflow::profiler::TraceMe hlo_module_activity( [&] { @@ -653,12 +550,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) { }, tensorflow::profiler::TraceMeLevel::kInfo); + xla::StatusOr> variable_infos = GatherVariableInfo( + ctx, *closure.compilation_result(), closure.num_constant_args()); + OP_REQUIRES_OK(ctx, variable_infos.status()); + OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos))); OP_REQUIRES_OK( ctx, launch_context.PopulateOutputs( - ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + ctx, closure.compilation_result(), execution_output->ConsumeResult(), /*missing_ctx_input_prefix=*/closure.num_constant_args(), - input_output_alias, closure.resource_var_snapshots())); + absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs)); } XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} diff --git 
a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h index 112408226a8..78707c8126d 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.h +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_compilation_cache.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -31,61 +32,6 @@ limitations under the License. namespace tensorflow { -// Holds some information about the platform on which an -// XlaLaunch/_XlaCompile/_XlaRun op must run on. -class XlaPlatformInfo { - public: - XlaPlatformInfo() : device_type_("") {} - XlaPlatformInfo(XlaPlatformInfo&&) = default; - explicit XlaPlatformInfo(const DeviceType device_type, - se::Platform::Id platform_id, - const XlaDevice::Metadata* xla_device_metadata, - se::DeviceMemoryAllocator* device_allocator) - : device_type_(device_type), - platform_id_(platform_id), - xla_device_metadata_(xla_device_metadata), - device_allocator_(device_allocator) {} - - XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; - - bool UseMultipleStreams() const { - return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); - } - - // Non-null only when run on an XLA device. - se::DeviceMemoryAllocator* custom_allocator() const { - return device_allocator_; - } - - DeviceType device_type() const { return device_type_; } - - // This is equal to xla_device_metadata()->platform()->id() if - // xla_device_metadata() is not nullptr. - se::Platform::Id platform_id() const { return platform_id_; } - - // This may be null if the op this XlaPlatformInfo is for was not placed on an - // XLA device. - const XlaDevice::Metadata* xla_device_metadata() const { - return xla_device_metadata_; - } - bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } - - private: - DeviceType device_type_; - se::Platform::Id platform_id_; - - // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the - // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the - // XlaLaunch/_XlaCompile/_XlaRun OpKernel. - const XlaDevice::Metadata* xla_device_metadata_; - - // If the op associated with this XlaPlatformInfo is placed on an XLA device - // then device_allocator_ is the xla::Backend's memory allocator. If the op - // is placed on a regular CPU or GPU device then device_allocator_ is null. - se::DeviceMemoryAllocator* device_allocator_; - - TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); -}; // XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. 
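// The XlaPlatformInfo class removed here now lives in xla_platform_info.h
// (included above). A minimal sketch of how a kernel picks it up, modeled on
// the constructor changes elsewhere in this patch; `MyXlaKernel` is a
// hypothetical name used only for illustration:
#include "tensorflow/compiler/jit/xla_platform_info.h"

class MyXlaKernel : public OpKernel {
 public:
  explicit MyXlaKernel(OpKernelConstruction* ctx)
      : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}

  void Compute(OpKernelContext* ctx) override {}

 private:
  // Captured once at construction; later used to decide whether the op runs
  // on an XLA device and whether it uses multiple streams.
  const XlaPlatformInfo platform_info_;
};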
// The only difference is that it does not require arguments to follow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 55ff57a04c5..19eb61b6f72 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1952,6 +1952,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "ParallelDynamicStitch", "ParameterizedTruncatedNormal", "PartitionedCall", + "PopulationCount", "Qr", "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", @@ -2014,6 +2015,7 @@ absl::flat_hash_set GetKnownXLAAllowlistOp() { "StatefulUniform", "StatefulUniformFullInt", "StatefulUniformInt", + "StatelessCase", "StatelessIf", "StatelessMultinomial", "StatelessRandomNormal", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 3ae72eb514c..e88319bb732 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -1829,7 +1829,7 @@ TEST(XlaCompilationTest, XLALiteAllowlist) { } EXPECT_TRUE(unknow_op.empty()) << "Someone added support for a new TF opeations inside XLA. They must " - "be included in the XLALite allowlist or blacklist:\n" + "be included in the XLALite allowlist or denylist:\n" << absl::StrJoin(unknow_op, "\n"); } } // namespace diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 72804ff57e4..7f585e70ec4 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -36,7 +36,7 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, if (!context->RankKnown(handle)) return Status::OK(); std::vector dims(context->Rank(handle)); - for (int32 i = 0; i < dims.size(); ++i) { + for (int32 i = 0, end = dims.size(); i < end; ++i) { dims[i] = context->Value(context->Dim(handle, i)); } return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index b8b11d2c7cd..38c23b7fa25 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -489,7 +489,7 @@ Status GetNodesRelatedToRefVariablesInDirection( /*stable_comparator=*/NodeComparatorName()); } - int old_result_size; + size_t old_result_size; int iterations = 0; const int kMaxIterations = 10 * 1000; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 62b0c0ab4cf..b1525337dbc 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -97,7 +97,7 @@ bool XlaCompilationCache::Signature::operator==(const Signature& other) const { if (arg_shapes != other.arg_shapes) return false; if (arg_values.size() != other.arg_values.size()) return false; - for (int i = 0; i < arg_values.size(); ++i) { + for (int i = 0, end = arg_values.size(); i < end; ++i) { if (arg_values[i].dtype() != other.arg_values[i].dtype() || arg_values[i].shape() != other.arg_values[i].shape() || arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) { @@ -158,7 +158,7 @@ Status XlaCompilationCache::BuildExecutable( std::vector argument_layouts( result.xla_input_shapes.size()); - for (int i = 0; i < result.xla_input_shapes.size(); ++i) { + for (int i = 0, end = 
result.xla_input_shapes.size(); i < end; ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } xla::ExecutableBuildOptions build_options; @@ -224,7 +224,7 @@ static xla::StatusOr> CreateGraph( // Create dummy _Arg nodes. Link these to `node` and also via a control // dependency edge to the _SOURCE node. - for (int64 i = 0; i < args.size(); ++i) { + for (int64 i = 0, end = args.size(); i < end; ++i) { Node* node; string arg_name = absl::StrCat("_arg", i); Status status = @@ -240,7 +240,7 @@ static xla::StatusOr> CreateGraph( } // Similarly with return values, create dummy _Retval nodes fed by `node`. - for (int64 i = 0; i < result_types.size(); ++i) { + for (int64 i = 0, end = result_types.size(); i < end; ++i) { Node* node; string retval_name = absl::StrCat("_retval", i); Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) @@ -271,7 +271,7 @@ Status XlaCompilationCache::CompileSingleOp( auto compile_op = [&](XlaCompiler* compiler, XlaCompiler::CompilationResult* result) { std::vector result_dtypes(ctx->num_outputs()); - for (int i = 0; i < result_dtypes.size(); ++i) { + for (int i = 0, end = result_dtypes.size(); i < end; ++i) { result_dtypes[i] = ctx->expected_output_dtype(i); } @@ -330,7 +330,7 @@ Status XlaCompilationCache::CompileImpl( if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << args.size(); - for (int i = 0; i < args.size(); i++) { + for (int i = 0, end = args.size(); i < end; i++) { VLOG(3) << i << ": " << args[i].HumanString(); } } diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index afaee614f02..73c512bfa6f 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -41,94 +42,82 @@ static std::vector GetResourceVariableIndices(OpKernelContext* ctx) { } Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, - const XlaDevice::Metadata& metadata, + XlaCompilationCache* cache, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args) { - xla::LocalClient* client = metadata.client(); + xla::LocalClient* client = static_cast(cache->client()); - // Builds an XLA allocator for the device. XlaComputationLaunchContext launch_context( client, client->backend().memory_allocator(), - /*allocate_xla_tensors=*/true, - /*use_multiple_streams=*/metadata.UseMultipleStreams()); + client->default_device_ordinal(), + /*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr, + platform_info_.xla_device_metadata() + ? platform_info_.xla_device_metadata()->UseMultipleStreams() + : false); - launch_context.PopulateInputs(ctx, result, variable_args, - /*missing_ctx_input_prefix=*/0); + std::map snapshot_ptrs; + for (auto& p : variable_args) { + snapshot_ptrs.emplace(p.first, + p.second.has_value() ? 
&p.second.value() : nullptr); + } + + const xla::HloInputOutputAliasConfig& input_output_alias = + executable->executable()->module().input_output_alias_config(); + xla::StatusOr> execution_inputs = + launch_context.PopulateInputs(ctx, result, snapshot_ptrs, + /*missing_ctx_input_prefix=*/0, + input_output_alias); + TF_RETURN_IF_ERROR(execution_inputs.status()); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - TF_RET_CHECK(stream); VLOG(2) << "Executing computation: " << name(); - for (const xla::ShapedBuffer* arg : launch_context.arguments()) { - VLOG(2) << name() << ": " << *arg; - } xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(client->backend().memory_allocator()); run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); run_options.set_rng_seed(GetXLARandomSeed()); - xla::StatusOr run_result = - executable->Run(launch_context.arguments(), run_options); + xla::StatusOr run_result = + executable->Run(execution_inputs.ConsumeValueOrDie(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); + xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie(); + xla::StatusOr> variable_infos = + GatherVariableInfo(ctx, *result, 0); + TF_RETURN_IF_ERROR(variable_infos.status()); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos))); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( - ctx, result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args)); + ctx, result, execution_output.ConsumeResult(), + /*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos), + input_output_alias, snapshot_ptrs)); return Status::OK(); } -Status XlaCompileOnDemandOp::MustArgumentBeConstant( - const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, bool* result) { - *result = false; - - // TODO(jmolloy): This could be expensive, so memoize. - std::vector constant_input_indices; - TF_RETURN_IF_ERROR(GetCompileTimeConstInputs( - op_kernel, &constant_input_indices, flib_runtime)); - *result = absl::c_binary_search(constant_input_indices, argument_idx); - return Status::OK(); -} - -// TODO(ycao): Remove the need to call ShouldArgumentBeConstant. Its benefit is -// not clear yet and it causes heavy constant analysis to run twice. 
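// The snapshot-to-pointer conversion above also appears in XlaRunOp::Compute;
// a small helper sketch of the same logic, assuming ResourceVarsSnapshot maps
// an input index to absl::optional<Tensor>. The helper name is illustrative,
// and a null pointer stands for a variable that was uninitialized at snapshot
// time:
static std::map<int, const Tensor*> SnapshotPointers(
    const ResourceVarsSnapshot& snapshots) {
  std::map<int, const Tensor*> ptrs;
  for (const auto& p : snapshots) {
    ptrs.emplace(p.first, p.second.has_value() ? &p.second.value() : nullptr);
  }
  return ptrs;
}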
-Status XlaCompileOnDemandOp::ShouldArgumentBeConstant( - const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, bool* result) { - return MustArgumentBeConstant(op_kernel, argument_idx, flib_runtime, result); -} - Status XlaCompileOnDemandOp::Compile( - OpKernelContext* ctx, const XlaDevice::Metadata& metadata, - const XlaCompiler::CompilationResult** result, - ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) { + OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, + XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, + xla::LocalExecutable** executable) { std::map constant_arguments; + + std::vector constant_input_indices; + TF_RETURN_IF_ERROR(GetCompileTimeConstInputs( + &ctx->op_kernel(), &constant_input_indices, ctx->function_library())); + CHECK(absl::c_is_sorted(constant_input_indices)); + for (int64 i = 0; i < ctx->num_inputs(); ++i) { const Tensor& device_tensor = ctx->input(i); if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) { if (xla_tensor->has_host_tensor()) { - bool should_arg_be_const; - TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i, - ctx->function_library(), - &should_arg_be_const)); - if (should_arg_be_const) { + if (absl::c_binary_search(constant_input_indices, i)) { constant_arguments[i] = xla_tensor->host_tensor(); } } } - if (constant_arguments.count(i) == 0) { - bool must_argument_be_const; - TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i, - ctx->function_library(), - &must_argument_be_const)); - - if (must_argument_be_const) { + if (!constant_arguments.count(i)) { + if (absl::c_binary_search(constant_input_indices, i)) { // Slow path; the argument is not available as a host constant so we // must fetch it synchronously. Tensor host_tensor; @@ -156,24 +145,16 @@ Status XlaCompileOnDemandOp::Compile( ResourceMgr* rm = ctx->resource_manager(); CHECK(rm); - XlaCompilationCache* cache; TF_RETURN_IF_ERROR(rm->LookupOrCreate( - rm->default_container(), "xla_cache", &cache, - [&](XlaCompilationCache** cache) { - *cache = new XlaCompilationCache(metadata.client(), - metadata.jit_device_type()); - return Status::OK(); + rm->default_container(), "xla_cache", cache, + [&](XlaCompilationCache** write_into_cache) { + return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache); })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) 
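// With the two helpers above removed, constant-argument detection now happens
// once per Compile call: GetCompileTimeConstInputs fills a sorted index list
// and each argument is tested with absl::c_binary_search. A condensed sketch
// of that flow (the probed index is hypothetical):
std::vector<int> constant_input_indices;
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
    &ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
CHECK(absl::c_is_sorted(constant_input_indices));  // required for binary search
const bool arg1_is_constant =
    absl::c_binary_search(constant_input_indices, 1);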
- core::ScopedUnref cache_ref(cache); - XlaCompiler::Options options; - options.device_type = metadata.jit_device_type(); - options.client = metadata.client(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.shape_representation_fn = metadata.shape_representation_fn(); + absl::optional tf_allocator_adapter; + XlaCompiler::Options options = + GenerateCompilerOptions(*cache, ctx, platform_info_, + /*has_ref_vars=*/true, &tf_allocator_adapter); XlaCompiler::CompileOptions compile_options; compile_options.is_entry_computation = true; @@ -194,19 +175,23 @@ Status XlaCompileOnDemandOp::Compile( constant_arguments, variable_infos, ctx, &args)); } - return cache->CompileSingleOp(options, args, ctx, compile_options, result, - executable); + return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result, + executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { const XlaCompiler::CompilationResult* result; xla::LocalExecutable* executable; - const XlaDevice::Metadata* metadata; - OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); ResourceVarsSnapshot variable_args; + XlaCompilationCache* cache; OP_REQUIRES_OK(ctx, - Compile(ctx, *metadata, &result, &variable_args, &executable)); - OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args)); + Compile(ctx, &result, &cache, &variable_args, &executable)); + + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args)); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index cc5f2f1e42f..095d3427d41 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/compiler/jit/xla_platform_info.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/function.h" @@ -35,25 +36,24 @@ namespace tensorflow { // vanilla TensorFlow op as long as the bridge supports it. 
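// Cache ownership after this change, condensed into one sketch: Compile()
// looks the cache up (or creates it) in the ResourceMgr and hands the raw
// pointer back to Compute(), which pins it with core::ScopedUnref for the
// duration of the Run() call. The explicit template argument below is an
// assumption about ResourceMgr::LookupOrCreate's signature:
XlaCompilationCache* cache = nullptr;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
    rm->default_container(), "xla_cache", &cache,
    [&](XlaCompilationCache** write_into_cache) {
      return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
    }));
// ...later, in the caller:
core::ScopedUnref cache_ref(cache);  // unrefs the cache when it goes out of scope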
class XlaCompileOnDemandOp : public OpKernel { public: - explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) + : OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {} void Compute(OpKernelContext* ctx) override; private: XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i); - Status ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, - bool* result); - Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx, - FunctionLibraryRuntime* flib_runtime, - bool* result); - Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + Status Compile(OpKernelContext* ctx, const XlaCompiler::CompilationResult** result, + XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable); - Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, + + Status Run(OpKernelContext* ctx, XlaCompilationCache* cache, const XlaCompiler::CompilationResult* result, xla::LocalExecutable* executable, const ResourceVarsSnapshot& variable_args); + + const XlaPlatformInfo platform_info_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index 8126059262b..f0555ae32e5 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) { return Status::OK(); })); mutex_lock ml(*variable->mu()); - OP_REQUIRES(context, variable->tensor()->dtype() == dtype_, - errors::InvalidArgument( - "Trying to assign variable with wrong dtype. Expected ", - DataTypeString(variable->tensor()->dtype()), " got ", - DataTypeString(dtype_))); + OP_REQUIRES( + context, + !variable->is_initialized || variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_))); variable->is_initialized = true; *variable->tensor() = value; } diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc deleted file mode 100644 index f720183e196..00000000000 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. 
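// Context for the relaxed XlaAssignVariableOp check above: variables can now
// be created eagerly as uninitialized placeholders, e.g. (sketch, mirroring
// the GetVariableInfosFromCtxInputs change later in this patch):
Var* variable = new Var(DT_INVALID);  // uninitialized placeholder
// variable->is_initialized is false, so the dtype comparison only applies
// once the first real write has fixed the variable's dtype.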
- -#include "absl/memory/memory.h" -#include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/compiler/jit/xla_device_ops.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" - -namespace tensorflow { - -const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; -const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; - -constexpr std::array kExecAllTypes = { - {DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, - DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}}; - -class XlaInterpreterDeviceFactory : public DeviceFactory { - public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; -}; - -Status XlaInterpreterDeviceFactory::ListPhysicalDevices( - std::vector* devices) { - devices->push_back( - absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0")); - - return Status::OK(); -} - -Status XlaInterpreterDeviceFactory::CreateDevices( - const SessionOptions& session_options, const string& name_prefix, - std::vector>* devices) { - static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels( - DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT); - (void)registrations; - - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - registration.autoclustering_policy = - XlaOpRegistry::AutoclusteringPolicy::kAlways; - registration.cluster_resource_variable_ops_unsafely = true; - registration.cluster_stack_ops = false; - registration.cluster_tensor_array_ops = true; - registration.cluster_stateful_rng_ops = true; - registration.cluster_control_trigger = true; - registration.elide_assert_and_checknumerics = true; - registration.cluster_variant_ops = true; - registration.cluster_slow_ops = true; - registration.cluster_inaccurate_ops = true; - XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER, - registration); - - TF_ASSIGN_OR_RETURN( - auto platform, se::MultiPlatformManager::PlatformWithName("Interpreter")); - - XlaDevice::Options options; - options.platform = platform; - options.device_name_prefix = name_prefix; - options.device_name = DEVICE_XLA_INTERPRETER; - options.device_ordinal = 0; - options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT; - options.use_multiple_streams = false; - devices->push_back(absl::make_unique(session_options, options)); - - return Status::OK(); -} - -// Set priority to be below the default priority (50), so that Interpreter is -// not selected as a high priority device over other default devices. See -// constructor comments for Registrar in -// tensorflow/core/common_runtime/device_factory.h for a list of priority for -// devices. 
-REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER, - XlaInterpreterDeviceFactory, 40); - -// Kernel registrations -static bool OpFilter(KernelDef* kdef) { return true; } - -REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, - kExecAllTypes); -REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, - kExecAllTypes); -REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); - -REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); -REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 5ca146969e0..3a6345afe9f 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -14,10 +14,62 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/xla_kernel_creator.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/jit/compilability_check_util.h" +#include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace { + +// Utility which searches for values in a sorted list by scanning over it once. +// No matter how many times ScanForValue is called, the list is scanned at most +// once. However, if a call to ScanForValue skips over a value, that value is +// not revisited in future calls to ScanForValue, so callers must take +// care to order their calls. +// +// Useful for merging multiple sorted lists in O(n) time. +class SinglePassSearch { + public: + // Creates a SinglePassSearch object that can be used to search in `values`. + // Does not take ownership of `values`. `values` must outlive this. + // `values` must be sorted. + explicit SinglePassSearch(const std::vector* values) + : current_index_(0), values_(values) {} + + // Scans forward in the vector looking for "value", updating the internal + // position in to the vector. + // Returns true iff the vector contains the given value at or after current + // position. + // Not thread-safe. 
+ bool ScanForValue(int value) { + while (current_index_ < values_->size() && + (*values_)[current_index_] <= value) { + if ((*values_)[current_index_] == value) { + current_index_++; + return true; + } + current_index_++; + } + return false; + } + + private: + int current_index_; + const std::vector* values_; +}; + +} // end namespace namespace tensorflow { @@ -27,6 +79,121 @@ bool XlaKernelCreator::CanCreateKernel( return CanCreateXlaKernel(props->node_def); } +static Status CreateXlaKernel(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + std::unique_ptr* kernel) { + if (!CanCreateXlaKernel(node_def)) { + return errors::Internal("Invalid node: ", node_def.ShortDebugString()); + } + + VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); + + // Make sure that kernels have been registered on the JIT device. + XlaOpRegistry::RegisterCompilationKernels(); + + // Only check for compilability if the MLIR bridge is not enabled. + if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; + if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { + std::vector + uncompilable_node_info; + for (const auto& it : uncompilable_nodes_map) { + for (const auto& info : it.second.second) { + uncompilable_node_info.emplace_back(info); + } + } + string message = absl::StrCat( + "Function invoked by the following node is not compilable: ", + SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); + absl::StrAppend(&message, "Uncompilable nodes:"); + for (const auto& node_info : uncompilable_node_info) { + string node_message = absl::StrCat("\n", node_info.name, ": ", + node_info.uncompilable_reason, "\n", + "\tStacktrace:\n"); + for (const auto& stack_frame : node_info.stack_trace) { + absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", + stack_frame.name, stack_frame.function_name); + } + absl::StrAppend(&message, node_message); + } + VLOG(1) << message; + return errors::InvalidArgument(message); + } + } + + // Get function body, constant args, and resource args. + const FunctionBody* fbody = nullptr; + std::vector constant_arg_indices; + std::vector resource_arg_indices; + TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( + flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); + + // Set input and output memory types. + MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); + // These indices are used only for optimization purposes. They allow us + // to loop over constant_arg_indices and resource_arg_indices only once + // while iterating over all the function arguments checking if it is a + // resource or a constant. + // The reason we optimized this code is because functions can have a lot of + // captured arguments. For example, the backward pass of ResNet50 takes in all + // 214 variables and a similar number of activations. + SinglePassSearch constants_search(&constant_arg_indices); + SinglePassSearch resources_search(&resource_arg_indices); + for (size_t i = 0; i < fbody->arg_types.size(); ++i) { + if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { + // Compile-time constants and resource handles are expected to be in + // host memory. + input_memory_types[i] = HOST_MEMORY; + } + } + // One might wonder, about the case where a compile-time constant argument + // (which must be in host memory) is also used as an input into an op, + // e.g. Add, that expects its inputs in device memory. 
Here is how it + // works now. + // First, what do we mean by "op expects an input in XYZ memory"? + // There are two types of "ops" here: the tf2xla kernel and the HLO + // computation it builds. The tf2xla kernel needs to retrieve the actual + // numeric value of the compile-time constant tensors, so it really expects + // them to be on in host memory. However, for other inputs, it refers to them + // using xla::ComputationDataHandle, which is just a symbolic handle that + // xla::ComputationBuilder assigns. How does this handle gets assigned for + // constant arguments? Even constant arguments get an _Arg node in the graph + // instantiated for Function compilation. The tf2xla kernel for constant _Arg + // nodes takes the constant value, converts it to XlaLiteral, and feeds it + // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This + // constant XlaLiteral is included in the HLO graph, and subsequently, in + // the actual executable, which is copied to the device before being + // executed. Thus, when this executable runs, the constant is available in + // device memory. + + // XlaLaunch kernel keeps all outputs (including constants, which it copies), + // in device memory except for resources. + MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); + for (size_t i = 0; i < fbody->ret_types.size(); ++i) { + if (fbody->ret_types[i] == DT_RESOURCE) { + output_memory_types[i] = HOST_MEMORY; + } + } + + // Create the kernel. + NameAttrList function; + TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); + Device* dev = flr->device(); + Status s; + auto props = std::make_shared( + &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); + OpKernelConstruction construction(DeviceType(dev->device_type()), dev, + dev->GetAllocator(AllocatorAttributes()), + flr, dev->resource_manager(), props, + input_memory_types, output_memory_types, + flr->graph_def_version(), &s); + + *kernel = absl::make_unique( + &construction, constant_arg_indices, resource_arg_indices, function, + /*has_ref_vars=*/false); + return s; +} + Status XlaKernelCreator::CreateKernel( FunctionLibraryRuntime* flr, const std::shared_ptr& props, @@ -34,19 +201,12 @@ Status XlaKernelCreator::CreateKernel( return CreateXlaKernel(flr, props->node_def, kernel); } -namespace { - -bool RegisterLaunchOpCreator() { +static bool RegisterLaunchOpCreator() { XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator(); RegisterDefaultCustomKernelCreator(xla_kernel_creator); return true; } static bool register_me = RegisterLaunchOpCreator(); -static bool register_xla = [] { - SetXlaIsEnabled(); - return true; -}(); -} // end namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_kernel_creator_util.cc b/tensorflow/compiler/jit/xla_kernel_creator_util.cc deleted file mode 100644 index 3cc68f2a1a4..00000000000 --- a/tensorflow/compiler/jit/xla_kernel_creator_util.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
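// A small usage sketch for the SinglePassSearch helper added above, with
// hypothetical index lists. Both lists must be sorted and the probes must be
// issued in increasing order, because the scan never moves backwards:
std::vector<int> constant_arg_indices = {0, 3};
std::vector<int> resource_arg_indices = {2, 3, 5};
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (int i = 0; i < 6; ++i) {
  bool host_memory =
      resources_search.ScanForValue(i) || constants_search.ScanForValue(i);
  // host_memory is true for i == 0, 2, 3, 5 and false for i == 1, 4.
}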
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/jit/xla_kernel_creator_util.h" - -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "tensorflow/compiler/jit/compilability_check_util.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_ops.h" -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" -#include "tensorflow/compiler/tf2xla/const_analysis.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/ptr_util.h" - -namespace tensorflow { -namespace { - -// Utility which searches for values in a sorted list by scanning over it once. -// No matter how many times ScanForValue is called, the list is scanned at most -// once. However, if a call to ScanForValue skips over a value, that value is -// not revisited in future calls to ScanForValue, so callers must take -// care to order their calls. -// -// Useful for merging multiple sorted lists in O(n) time. -class SinglePassSearch { - public: - // Creates a SinglePassSearch object that can be used to search in `values`. - // Does not take ownership of `values`. `values` must outlive this. - // `values` must be sorted. - explicit SinglePassSearch(const std::vector* values) - : current_index_(0), values_(values) {} - - // Scans forward in the vector looking for "value", updating the internal - // position in to the vector. - // Returns true iff the vector contains the given value at or after current - // position. - // Not thread-safe. - bool ScanForValue(int value) { - while (current_index_ < values_->size() && - (*values_)[current_index_] <= value) { - if ((*values_)[current_index_] == value) { - current_index_++; - return true; - } - current_index_++; - } - return false; - } - - private: - int current_index_; - const std::vector* values_; -}; -} // namespace - -Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def, - std::unique_ptr* kernel) { - if (!CanCreateXlaKernel(node_def)) { - return errors::Internal("Invalid node: ", node_def.ShortDebugString()); - } - - VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString(); - - // Make sure that kernels have been registered on the JIT device. 
- XlaOpRegistry::RegisterCompilationKernels(); - RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map; - if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) { - std::vector - uncompilable_node_info; - for (const auto& it : uncompilable_nodes_map) { - for (const auto& info : it.second.second) { - uncompilable_node_info.emplace_back(info); - } - } - string message = absl::StrCat( - "Function invoked by the following node is not compilable: ", - SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n"); - absl::StrAppend(&message, "Uncompilable nodes:"); - for (const auto& node_info : uncompilable_node_info) { - string node_message = - absl::StrCat("\n", node_info.name, ": ", - node_info.uncompilable_reason, "\n", "\tStacktrace:\n"); - for (const auto& stack_frame : node_info.stack_trace) { - absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n", - stack_frame.name, stack_frame.function_name); - } - absl::StrAppend(&message, node_message); - } - VLOG(1) << message; - return errors::InvalidArgument(message); - } - - // Get function body, constant args, and resource args. - const FunctionBody* fbody = nullptr; - std::vector constant_arg_indices; - std::vector resource_arg_indices; - TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources( - flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices)); - - // Set input and output memory types. - MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY); - // These indices are used only for optimization purposes. They allow us - // to loop over constant_arg_indices and resource_arg_indices only once - // while iterating over all the function arguments checking if it is a - // resource or a constant. - // The reason we optimized this code is because functions can have a lot of - // captured arguments. For example, the backward pass of ResNet50 takes in all - // 214 variables and a similar number of activations. - SinglePassSearch constants_search(&constant_arg_indices); - SinglePassSearch resources_search(&resource_arg_indices); - for (size_t i = 0; i < fbody->arg_types.size(); ++i) { - if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) { - // Compile-time constants and resource handles are expected to be in - // host memory. - input_memory_types[i] = HOST_MEMORY; - } - } - // One might wonder, about the case where a compile-time constant argument - // (which must be in host memory) is also used as an input into an op, - // e.g. Add, that expects its inputs in device memory. Here is how it - // works now. - // First, what do we mean by "op expects an input in XYZ memory"? - // There are two types of "ops" here: the tf2xla kernel and the HLO - // computation it builds. The tf2xla kernel needs to retrieve the actual - // numeric value of the compile-time constant tensors, so it really expects - // them to be on in host memory. However, for other inputs, it refers to them - // using xla::ComputationDataHandle, which is just a symbolic handle that - // xla::ComputationBuilder assigns. How does this handle gets assigned for - // constant arguments? Even constant arguments get an _Arg node in the graph - // instantiated for Function compilation. The tf2xla kernel for constant _Arg - // nodes takes the constant value, converts it to XlaLiteral, and feeds it - // to xla::ComputationBuilder.ConstantLiteral, which returns the handle. 
This - // constant XlaLiteral is included in the HLO graph, and subsequently, in - // the actual executable, which is copied to the device before being - // executed. Thus, when this executable runs, the constant is available in - // device memory. - - // XlaLaunch kernel keeps all outputs (including constants, which it copies), - // in device memory except for resources. - MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY); - for (size_t i = 0; i < fbody->ret_types.size(); ++i) { - if (fbody->ret_types[i] == DT_RESOURCE) { - output_memory_types[i] = HOST_MEMORY; - } - } - - // Create the kernel. - NameAttrList function; - TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function)); - Device* dev = flr->device(); - Status s; - auto props = std::make_shared( - &fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types); - OpKernelConstruction construction(DeviceType(dev->device_type()), dev, - dev->GetAllocator(AllocatorAttributes()), - flr, dev->resource_manager(), props, - input_memory_types, output_memory_types, - flr->graph_def_version(), &s); - - *kernel = absl::make_unique( - &construction, constant_arg_indices, resource_arg_indices, function, - /*has_ref_vars=*/false); - return s; -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 7f107aaef11..19e2b5a2bb5 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() { Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, absl::Span variable_indices, std::vector* result) { - std::vector resource_handles; - absl::c_transform( - variable_indices, std::back_inserter(resource_handles), - [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); }); - - std::vector> variables; - Status s = LookupResources(ctx, resource_handles, &variables); - if (!s.ok()) { - errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage); - return s; - } - result->clear(); result->reserve(variable_indices.size()); - for (int i = 0; i < variable_indices.size(); i++) { - // *Release* the variable because we're going to unref it later in - // ~VariableInfo. - Var* variable = variables[i].release(); - int input_idx = variable_indices[i]; - std::string var_name = HandleFromInput(ctx, input_idx).name(); - result->emplace_back(input_idx, var_name, variable); + for (int var_idx : variable_indices) { + Var* variable = nullptr; + ResourceHandle handle = HandleFromInput(ctx, var_idx); + TF_RETURN_IF_ERROR( + LookupOrCreateResource(ctx, handle, &variable, [&](Var** ptr) { + // This var is uninitialized for now. + *ptr = new Var(DT_INVALID); + return Status::OK(); + })); + result->emplace_back(var_idx, handle.name(), variable); } - return Status::OK(); } @@ -166,7 +156,7 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, absl::Span variable_indices, absl::Span variable_infos, ResourceVarsSnapshot* result) { - for (int i = 0; i < variable_indices.size(); i++) { + for (int i = 0, end = variable_indices.size(); i < end; i++) { Var* var = variable_infos[i].var(); (*result)[variable_indices[i]] = var ? 
absl::make_optional(*var->tensor()) : absl::nullopt; @@ -176,35 +166,73 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, XlaComputationLaunchContext::XlaComputationLaunchContext( xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors, bool use_multiple_streams) + int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams) : client_(client), xla_allocator_(xla_allocator), allocate_xla_tensors_(allocate_xla_tensors), - use_multiple_streams_(use_multiple_streams) { + use_multiple_streams_(use_multiple_streams), + device_ordinal_(device_ordinal) { if (use_multiple_streams_) { CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must " "be allocating XLA tensors!"; } } -void XlaComputationLaunchContext::PopulateInputs( +// Fills in `execution_input` with `buffer` for `index`. +static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input, + xla::ShapeIndex index, + se::DeviceMemoryBase& buffer, + bool donate_buffer, int device_ordinal, + se::DeviceMemoryAllocator* allocator) { + xla::MaybeOwningDeviceMemory* in_buffer = + execution_input.MutableBuffer(index); + if (donate_buffer) { + *in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator); + buffer = se::DeviceMemoryBase(); + } else { + *in_buffer = buffer; + } +} + +xla::StatusOr> +XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, - const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) { - // Build ShapedBuffers that point directly to the Tensor buffers. - arg_ptrs_ = - std::vector(compilation_result->xla_input_shapes.size()); + const std::map& resource_vars, + int missing_ctx_input_prefix, + const xla::HloInputOutputAliasConfig& input_output_alias) { + std::vector arguments; + arguments.reserve(compilation_result->xla_input_shapes.size()); xla::TransferManager* transfer_manager = client_->backend().transfer_manager(); - for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) { + for (int i = 0, end = compilation_result->xla_input_shapes.size(); i < end; + ++i) { int arg_num = compilation_result->input_mapping[i]; CHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = compilation_result->xla_input_shapes[i]; - const Tensor* t = variables.count(arg_num) - ? &(variables.at(arg_num).value()) + const xla::Shape& device_shape = + transfer_manager->HostShapeToDeviceShape(shape); + + bool is_resource_variable = resource_vars.count(arg_num); + bool is_updated_resource_variable = + is_resource_variable && + absl::c_any_of(compilation_result->resource_updates, + [&](const XlaCompiler::ResourceUpdate& update) { + return update.input_index == i && update.modified; + }); + + const Tensor* t = is_resource_variable + ? 
resource_vars.at(arg_num) : &(ctx->input(arg_num - missing_ctx_input_prefix)); CHECK(t); + bool donate_buffer = + t->RefCountIsOne() && is_updated_resource_variable && + input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{}); + VLOG(3) << "Processing input: " << i + << "; is_resource_variable=" << is_resource_variable + << "; is_updated_resource_variable=" << is_updated_resource_variable + << "; donate_buffer=" << donate_buffer; if (use_multiple_streams_) { CHECK(ctx->op_device_context() && ctx->op_device_context()->stream()) @@ -215,23 +243,28 @@ void XlaComputationLaunchContext::PopulateInputs( ctx->op_device_context()->stream()); } - if (xla::Shape::Equal().MinorToMajorOnlyInLayout()( - shape, transfer_manager->HostShapeToDeviceShape(shape))) { + arguments.emplace_back(device_shape, shape); + xla::ExecutionInput& execution_input = arguments.back(); + if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) { se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t); - arg_buffers_.emplace_back( - /*on_host_shape=*/shape, /*on_device_shape=*/shape, - client_->platform(), client_->default_device_ordinal()); - arg_buffers_.back().set_buffer(dmem, /*index=*/{}); - arg_ptrs_[i] = &arg_buffers_.back(); + PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem, + donate_buffer, device_ordinal_, + xla_allocator_); } else { - const XlaTensor* xla_tensor = XlaTensor::FromTensor(t); + XlaTensor* xla_tensor = XlaTensor::FromTensor(t); CHECK(xla_tensor && xla_tensor->has_shaped_buffer()); - arg_ptrs_[i] = const_cast(&xla_tensor->shaped_buffer()); + xla_tensor->shaped_buffer().buffers().ForEachMutableElement( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + PopulateExecutionInputBuffer(execution_input, index, *buffer, + donate_buffer, device_ordinal_, + xla_allocator_); + }); } } + return std::move(arguments); } -// Construct the tensor for given type and buffer. +// Construct the tensor for the given type and buffer. static Tensor MakeTensor(DataType dtype, const TensorShape& shape, se::DeviceMemoryBase buffer, Allocator* allocator) { size_t expected_size = shape.num_elements() * DataTypeSize(dtype); @@ -247,28 +280,26 @@ static Tensor GetOrCreateTensorForOutput( int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, absl::Span input_mapping, - const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype, - const TensorShape& output_shape, se::DeviceMemoryBase output_buffer, - Allocator* output_allocator) { + const std::map& resource_vars_snapshots, + DataType output_dtype, const TensorShape& output_shape, + se::DeviceMemoryBase output_buffer, Allocator* output_allocator) { xla::ShapeIndex output_index = input_output_alias.shape().IsTuple() ? xla::ShapeIndex({output_num}) : xla::ShapeIndex({}); + CHECK(input_output_alias.shape().IsTuple() || output_num == 0); if (absl::optional alias = input_output_alias.GetAliasedParameter(output_index)) { + VLOG(3) << "Found alias: " << alias->ToString(); int tf_param = input_mapping[alias->parameter_number] - missing_ctx_input_prefix; - const Tensor* input_tensor = &ctx->input(tf_param); - - // If input tensor is a resource variable, alias to the snapshot we took at - // entry time. 
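// The donate_buffer predicate above, spelled out: the input buffer is handed
// over (donated) to XLA only when (1) no other TF tensor shares it
// (RefCountIsOne), (2) it backs a resource variable that this computation
// updates, and (3) the compiled HLO module aliases that parameter to an
// output, so the storage can be reused in place. Equivalent sketch:
const bool donate_buffer =
    t->RefCountIsOne() && is_updated_resource_variable &&
    input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});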
- if (input_tensor->dtype() == DT_RESOURCE) { - const absl::optional& v = - resource_var_snapshots.at(missing_ctx_input_prefix + tf_param); - CHECK(v.has_value()); - return *v; + const Tensor input_tensor = + ctx->input(tf_param).dtype() != DT_RESOURCE + ? ctx->input(tf_param) + : *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param); + if (output_buffer.opaque() == input_tensor.data()) { + return input_tensor; } - return *input_tensor; } return MakeTensor(output_dtype, output_shape, output_buffer, output_allocator); @@ -291,12 +322,10 @@ static Status SetOutputForConstant( OpKernelContext* ctx, se::Stream* stream, const XlaCompiler::CompilationResult* compilation_result, int output_num) { CHECK(compilation_result->outputs[output_num].is_constant); - // Output is a constant. const Tensor& const_tensor = compilation_result->outputs[output_num].constant_value; Tensor* output_tensor; - const size_t total_bytes = const_tensor.TotalBytes(); - if (stream && total_bytes > 0) { + if (stream && const_tensor.TotalBytes() > 0) { // Copy host -> device. (Empty tensors don't have backing buffers.) // Manually allocate memory using an XlaTensorBuffer so we can allocate // as much memory as the device requires (as given by @@ -335,52 +364,55 @@ static Status SetOutputForConstant( return Status::OK(); } -// Creates a list of updates resource variables. -static xla::StatusOr> GatherVariableInfo( - OpKernelContext* ctx, - const XlaCompiler::CompilationResult* compilation_result, - int missing_ctx_input_prefix) { - std::vector variable_infos; - variable_infos.reserve(compilation_result->resource_updates.size()); +static xla::StatusOr GetOrCreateResourceVar( + OpKernelContext* ctx, const ResourceHandle& handle, + const XlaCompiler::ResourceUpdate& write) { + Var* variable = nullptr; + TF_RETURN_IF_ERROR( + LookupOrCreateResource(ctx, handle, &variable, [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); + return variable; +} - for (int i = 0; i < compilation_result->resource_updates.size(); ++i) { +xla::StatusOr> GatherVariableInfo( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult& compilation_result, + int missing_ctx_input_prefix) { + std::vector out; + out.reserve(compilation_result.resource_updates.size()); + for (int i = 0; i < compilation_result.resource_updates.size(); ++i) { const XlaCompiler::ResourceUpdate& write = - compilation_result->resource_updates[i]; + compilation_result.resource_updates[i]; int actual_input_index = write.input_index - missing_ctx_input_prefix; if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } - // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, - // not a Tensor. 
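// GetOrCreateResourceVar above returns a StatusOr of Var*, which the callers
// below unwrap with TF_ASSIGN_OR_RETURN. The macro expands to roughly this
// hand-written equivalent:
xla::StatusOr<Var*> maybe_var = GetOrCreateResourceVar(ctx, handle, write);
if (!maybe_var.ok()) return maybe_var.status();
Var* variable = maybe_var.ValueOrDie();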
- Var* variable = nullptr; const ResourceHandle handle = HandleFromInput(ctx, actual_input_index); - TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, handle, &variable, - [&write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - variable_infos.emplace_back(actual_input_index, handle.name(), variable); + TF_ASSIGN_OR_RETURN(Var * variable, + GetOrCreateResourceVar(ctx, handle, write)); + out.emplace_back(actual_input_index, handle.name(), variable); } - return variable_infos; + return std::move(out); } Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, ScopedShapedBuffer output, int missing_ctx_input_prefix, + absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const ResourceVarsSnapshot& resource_var_snapshots) { + const std::map& resource_vars) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; Allocator* allocator = ctx->device()->GetAllocator({}); // Computation output should always be a tuple. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString(); - VLOG(2) << "Result tuple shape (on device): " - << output.on_device_shape().DebugString(); - } + VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString(); + VLOG(2) << "Result tuple shape (on device): " + << output.on_device_shape().DebugString(); CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size()); // If the on-host-shape isn't a tuple, create a new single-element tuple @@ -435,11 +467,11 @@ Status XlaComputationLaunchContext::PopulateOutputs( // Copy XLA results to the OpOutputList. int output_num = 0; - for (int i = 0; i < ctx->num_outputs(); ++i) { + for (int i = 0, end = ctx->num_outputs(); i < end; ++i) { const TensorShape& shape = output_tensor_shapes[i]; const DataType& type = compilation_result->outputs[i].type; - VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " - << DataTypeString(type); + VLOG(2) << "Populating output for retval " << i << " shape " + << shape.DebugString() << " type " << DataTypeString(type); if (type == DT_VARIANT) { return errors::Unimplemented( "Support for TensorList crossing the XLA/TF boundary " @@ -467,30 +499,38 @@ Status XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); Tensor output_tensor = GetOrCreateTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, - compilation_result->input_mapping, resource_var_snapshots, + compilation_result->input_mapping, resource_vars, ctx->expected_output_dtype(i), shape, buffer, allocator); - output.set_buffer(se::OwningDeviceMemory(), {output_num}); ctx->set_output(i, output_tensor); } + output.set_buffer(se::OwningDeviceMemory(), {output_num}); ++output_num; } - - if (VLOG_IS_ON(3)) { - VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString(); - } } - // Apply variable updates, if any. - VLOG(2) << "Applying variable updates"; - TF_ASSIGN_OR_RETURN( - std::vector variable_infos, - GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix)); - TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); + // input_index -> index into variable_infos. 
+ absl::flat_hash_map variable_info_lookup; + for (int i = 0; i < variable_infos.size(); i++) { + variable_info_lookup.emplace(variable_infos[i].index(), i); + } - for (int i = 0; i < compilation_result->resource_updates.size(); ++i) { + // Apply variable updates, if any. + for (int i = 0, end = compilation_result->resource_updates.size(); i < end; + ++i) { const XlaCompiler::ResourceUpdate& write = compilation_result->resource_updates[i]; - if (variable_infos[i].var()->tensor()->dtype() != write.type) { + int actual_input_index = write.input_index - missing_ctx_input_prefix; + CHECK_GE(actual_input_index, 0); + CHECK_LT(actual_input_index, ctx->num_inputs()); + Var* var = variable_infos[variable_info_lookup[actual_input_index]].var(); + CHECK(var); + + VLOG(2) << "Updating variable #" << i + << " at input index: " << actual_input_index << " with shape " + << write.shape.DebugString() << "; variable tensor has shape: " + << var->tensor()->shape().DebugString(); + + if (var->is_initialized && var->tensor()->dtype() != write.type) { return errors::Internal("Mismatched type in variable write"); } @@ -504,21 +544,21 @@ Status XlaComputationLaunchContext::PopulateOutputs( } } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); - output.set_buffer(se::OwningDeviceMemory(), {output_num}); output_tensor = GetOrCreateTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, - compilation_result->input_mapping, resource_var_snapshots, write.type, + compilation_result->input_mapping, resource_vars, write.type, write.shape, buffer, allocator); } - *variable_infos[i].var()->tensor() = output_tensor; - variable_infos[i].var()->is_initialized |= write.modified; + output.set_buffer(se::OwningDeviceMemory(), {output_num}); + var->is_initialized |= write.modified; + *var->tensor() = output_tensor; ++output_num; } return Status::OK(); } Status XlaComputationLaunchContext::BuildXlaCompilerArguments( - const std::map& constant_args, + const std::map& must_be_constant_args, absl::Span variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); @@ -534,9 +574,9 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; - if (constant_args.count(input_num) > 0) { + if (must_be_constant_args.count(input_num) > 0) { // Handles compile-time constants. - const Tensor& input = constant_args.at(input_num); + const Tensor& input = must_be_constant_args.at(input_num); TF_RET_CHECK(input.dtype() != DT_RESOURCE); arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); @@ -562,7 +602,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.name = std::string(variable.name()); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = XlaResource::kVariable; - if (variable.var()) { + if (variable.var() && variable.var()->is_initialized) { const Tensor* value = variable.var()->tensor(); arg.type = value->dtype(); arg.shape = value->shape(); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 92b6c4c8a08..b34b3059a4f 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -81,6 +81,12 @@ class VariableInfo { bool lock_held_ = false; }; +// Creates a list of updated resource variables. 
+xla::StatusOr> GatherVariableInfo( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult& compilation_result, + int missing_ctx_input_prefix); + // Takes a snapshot of the values of resource variable arguments, whose indices // are specified in `variable_indices` argument. We snapshot tensors that back // resource variables since concurrent updates may modify the shape, and it is @@ -124,7 +130,7 @@ class XlaComputationLaunchContext { // objects. XlaComputationLaunchContext(xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator, - bool allocate_xla_tensors, + int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams); // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch @@ -142,10 +148,12 @@ class XlaComputationLaunchContext { // missing and adjusts input indices accordingly. All elements in kernel's // input_mapping must be greater than or equal to `missing_ctx_input_prefix` // (in other words, no inputs actually required by the kernel can be missing). - void PopulateInputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* compilation_result, - const ResourceVarsSnapshot& variables, - int missing_ctx_input_prefix); + xla::StatusOr> PopulateInputs( + OpKernelContext* ctx, + const XlaCompiler::CompilationResult* compilation_result, + const std::map& resource_vars, + int missing_ctx_input_prefix, + const xla::HloInputOutputAliasConfig& input_output_alias); // Given the XLA output in `output`, populate all outputs of `ctx`. Also // writes out the resource variable updates. @@ -161,20 +169,16 @@ class XlaComputationLaunchContext { OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, + absl::Span variable_infos, const xla::HloInputOutputAliasConfig& input_output_alias, - const ResourceVarsSnapshot& resource_var_snapshots); - - // Return the argument list. Only valid after PopulateInputs() has been - // called. - const std::vector& arguments() const { return arg_ptrs_; } + const std::map& resource_vars); private: xla::LocalClient* client_; se::DeviceMemoryAllocator* xla_allocator_; bool allocate_xla_tensors_; bool use_multiple_streams_; - std::deque arg_buffers_; - std::vector arg_ptrs_; + int device_ordinal_; }; // A simple TensorBuffer implementation that allows us to create Tensors that diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc new file mode 100644 index 00000000000..e2a89353055 --- /dev/null +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -0,0 +1,158 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_platform_info.h" + +#include "tensorflow/compiler/xla/client/client_library.h" + +namespace tensorflow { + +Status BuildXlaCompilationCache(OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + + xla::StatusOr compiler_for_platform = + xla::Compiler::GetForPlatform(platform.ValueOrDie()); + if (!compiler_for_platform.ok()) { + // In some rare cases (usually in unit tests with very small clusters) we + // may end up transforming an XLA cluster with at least one GPU operation + // (which would normally force the cluster to be compiled using XLA:GPU) + // into an XLA cluster with no GPU operations (i.e. containing only CPU + // operations). Such a cluster can fail compilation (in way that + // MarkForCompilation could not have detected) if the CPU JIT is not linked + // in. + // + // So bail out of _XlaCompile in this case, and let the executor handle the + // situation for us. + const Status& status = compiler_for_platform.status(); + if (status.code() == error::NOT_FOUND) { + return errors::Unimplemented("Could not find compiler for platform ", + platform.ValueOrDie()->Name(), ": ", + status.ToString()); + } + } + + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) { + DeviceType device_type = ctx->device_type(); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + se::DeviceMemoryAllocator* custom_allocator = nullptr; + + if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + platform_id = se::host::kHostPlatformId; + } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { + platform_id = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator above uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. 
The XlaCompiler needs a real + // allocator to allocate real buffers. + platform_id = xla_device_metadata->platform()->id(); + custom_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + return XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + custom_allocator); +} + +se::DeviceMemoryAllocator* GetAllocator( + absl::optional* tf_allocator_adapter, + OpKernelContext* ctx, const XlaPlatformInfo& platform_info) { + if (platform_info.custom_allocator()) { + return platform_info.custom_allocator(); + } + if (!ctx->op_device_context()) { + // Stream is not set for the host platform. + se::Platform* platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()) + .ValueOrDie(); + tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform); + return &tf_allocator_adapter->value(); + } + tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), + ctx->op_device_context()->stream()); + return &tf_allocator_adapter->value(); +} + +XlaCompiler::Options GenerateCompilerOptions( + XlaCompilationCache* cache, OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, bool has_ref_vars, + absl::optional* tf_allocator_adapter) { + XlaCompiler::Options options; + options.client = static_cast(cache->client()); + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } + options.device_type = cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = + GetAllocator(tf_allocator_adapter, ctx, platform_info); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + // If reference variables are not present in the graph, we can safely alias + // passthrough parameters without performing a copy. + options.alias_passthrough_params = + !has_ref_vars && !platform_info.is_on_xla_device(); + return options; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h new file mode 100644 index 00000000000..dac45529ac9 --- /dev/null +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/stream_executor/tf_allocator_adapter.h"
+
+namespace tensorflow {
+
+// Holds some information about the platform on which an
+// XlaLaunch/_XlaCompile/_XlaRun op must run. Provides a common layer of
+// abstraction for normal and XLA devices.
+class XlaPlatformInfo {
+ public:
+  XlaPlatformInfo() : device_type_("") {}
+  XlaPlatformInfo(XlaPlatformInfo&&) = default;
+  explicit XlaPlatformInfo(const DeviceType device_type,
+                           se::Platform::Id platform_id,
+                           const XlaDevice::Metadata* xla_device_metadata,
+                           se::DeviceMemoryAllocator* device_allocator)
+      : device_type_(device_type),
+        platform_id_(platform_id),
+        xla_device_metadata_(xla_device_metadata),
+        device_allocator_(device_allocator) {}
+
+  XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
+
+  bool UseMultipleStreams() const {
+    return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
+  }
+
+  // Non-null only when run on an XLA device.
+  se::DeviceMemoryAllocator* custom_allocator() const {
+    return device_allocator_;
+  }
+
+  DeviceType device_type() const { return device_type_; }
+
+  // This is equal to xla_device_metadata()->platform()->id() if
+  // xla_device_metadata() is not nullptr.
+  se::Platform::Id platform_id() const { return platform_id_; }
+
+  // This may be null if the op this XlaPlatformInfo is for was not placed on an
+  // XLA device.
+  const XlaDevice::Metadata* xla_device_metadata() const {
+    return xla_device_metadata_;
+  }
+  bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
+
+ private:
+  DeviceType device_type_;
+  se::Platform::Id platform_id_;
+
+  // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
+  // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
+  // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
+  const XlaDevice::Metadata* xla_device_metadata_;
+
+  // If the op associated with this XlaPlatformInfo is placed on an XLA device,
+  // then device_allocator_ is the xla::Backend's memory allocator. If the op
+  // is placed on a regular CPU or GPU device, then device_allocator_ is null.
+  se::DeviceMemoryAllocator* device_allocator_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// Creates an XLA compilation cache and returns it in `cache`.
+Status BuildXlaCompilationCache(OpKernelContext* ctx,
+                                const XlaPlatformInfo& platform_info,
+                                XlaCompilationCache** cache);
+
+// Returns information about the platform from the kernel context.
+XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
+
+// Returns the custom allocator from the platform info if non-null, or
+// populates and returns a pointer to the allocator adapter built from the
+// context's allocator.
+//
+// This is necessary because for XLA devices the underlying TF allocator returns
+// dummy tensors.
+se::DeviceMemoryAllocator* GetAllocator(
+    absl::optional* tf_allocator_adapter,
+    OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
+
+// Returns options for the XLA compiler, and writes the allocator that will be
+// used into `tf_allocator_adapter`.
+XlaCompiler::Options GenerateCompilerOptions( + XlaCompilationCache* cache, OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, bool has_ref_vars, + absl::optional* tf_allocator_adapter); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 57f923caa91..01c187790b7 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -150,6 +150,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", "//tensorflow/compiler/mlir/xla:xla_mlir_translate", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", diff --git a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md index 2fe109c1783..8e7e605fc4c 100644 --- a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md +++ b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md @@ -74,7 +74,6 @@ We have several choices on how to lower the host-side part from LHLO: * (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as TFRT ops are interpreted by C++ code. * (Con) host side is under development and not tested. - * (Con) the JAX integration isn’t clear from a runtime point of view * Jitted CPU code * (Pro) great lower-ability. Create a few loops and conditions and it's done. @@ -84,8 +83,7 @@ We have several choices on how to lower the host-side part from LHLO: dynamic loading, etc). * Existing (interpreting) XLA runtime -Tentative conclusion: Use jitted CPU code during the transition, and optionally -adopt TFRT in the end. +Decision: adopt TFRT, but also support jitting CPU code in TFRT. ## Migrating Device LLVM IR (Task 3) @@ -114,7 +112,7 @@ end state of each XLA op: * (Cost) Will be throw-away work if we want to ultimately migrate to Standard. * (Benefit) It is easy and mechanical. Can be done in a short period. - * (Benefit) It doesn't benefit more compared to a). + * (Benefit) It doesn't benefit more compared to (1). 1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops: * (Cost) Lifting existing emitters to Standard introduces some challenges. Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring @@ -134,6 +132,19 @@ end state of each XLA op: * (Benefit) unified stack; community support; portability; more optimization potentials. +Conclusions: + +* Don't go for (2). (1) or (3) are just better than (2). (2) costs more than + (1), since it requires a lot of mechanical refactoring. With (1) we can + still achieve the goal of enabling XLA to pick up MLIR emitters. This is by + doing LHLO -> LLVM IR -> run legacy device emitters. +* ElementalIrEmitter ops go for (4), but not incrementally. There is no way to + do it op by op, because all elementally-emitted ops are connected into the + same graph. This work can also serve as a unification point of several + on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg). +* All other ops go for (1). As a stretch goal, they might be migrated to (3) + or (4). + ## Prioritization While all three tasks mentioned above are parallelizable, under limited @@ -210,26 +221,19 @@ The exact profiling can't be easily done for MLIR-generated ops, since: ### Step 3: (Task 2) Migrating Thunks -This step migrates all host ops and library calls. 
This step will eliminate most -of the thunks and produce serializable MLIR instead. - -There are roughly three kinds of thunks: - +As a note, there are roughly three kinds of thunks: * KernelThunk, which launches a kernel. * Control flow thunks, which has host control flow logic (conditional, while, for, sequence) and launch body kernels. * Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc. -The **bottom line** is to: +The plan is: +* Make Thunks (de)serializable. +* Help improve TFRT to a state where it can support these semantics. +* As the state improves, migrate individual thunks incrementally. -* Create a Thunk dialect that provides (de)serialize logic for all existing - C++-based Thunks. -* Change emitters to emit a graph of Thunk dialect. - -**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk -can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG -Dialect for loops and conditions, combined with LaunchKernelOp. This optional -step requires profiling and stream support. +These action items are only partially ordered. The actual execution order / +engineering parallelism is to be evaluated as it goes. ### Step 4: (Task 3) Migrated ElementalIrEmitter diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index 9f6856f3636..edbf3663a89 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -52,7 +52,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): native.py_test( name = name, srcs = ["@llvm-project//llvm:lit"], - tags = tags + ["no_windows"], + tags = tags + ["no_pip", "no_windows"], args = [ "tensorflow/compiler/mlir/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", ] + features, diff --git a/tensorflow/compiler/mlir/hlo/.gitignore b/tensorflow/compiler/mlir/hlo/.gitignore new file mode 100644 index 00000000000..cc1696bf575 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/.gitignore @@ -0,0 +1,4 @@ +build +llvm-project +llvm-build + diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index c7bda887db0..126d44670a0 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -55,6 +55,38 @@ filegroup( ], ) +gentbl( + name = "MhloPassIncGen", + strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/", + tbl_outs = [ + ( + "-gen-pass-decls -name MHLO", + "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td", + td_srcs = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + +gentbl( + name = "LmhloPassIncGen", + strip_include_prefix = "include/mlir-hlo/Dialect/mhlo/transforms/", + tbl_outs = [ + ( + "-gen-pass-decls -name LMHLO", + "include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td", + td_srcs = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) + gentbl( name = "chlo_ops_inc_gen", strip_include_prefix = "include", @@ -76,8 +108,8 @@ gentbl( tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"), + 
("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td", @@ -106,14 +138,36 @@ gentbl( td_srcs = [":hlo_ops_td_files"], ) +gentbl( + name = "hlo_ops_pattern_gen", + strip_include_prefix = "lib/Dialect/mhlo/IR/", + tbl_outs = [ + ( + "-gen-rewriters", + "lib/Dialect/mhlo/IR/hlo_patterns.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "lib/Dialect/mhlo/IR/hlo_patterns.td", + td_relative_includes = [ + "include", + ], + td_srcs = [ + ":hlo_ops_td_files", + "@llvm-project//mlir:StdOpsTdFiles", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td", + "@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td", + ], +) + gentbl( name = "lhlo_ops_inc_gen", strip_include_prefix = "include", tbl_outs = [ ("-gen-op-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"), ("-gen-op-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"), - ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"), - ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"), + ("-gen-struct-attr-decls", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc"), + ("-gen-struct-attr-defs", "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td", @@ -126,11 +180,12 @@ gentbl( #TODO(aminim): revisit the naming and grouping of these rules post-move. gentbl( name = "canonicalize_inc_gen", + strip_include_prefix = "lib/Dialect/mhlo/IR/", tbl_outs = [ - ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_canonicalize.inc"), + ("-gen-rewriters", "lib/Dialect/mhlo/IR/mhlo_canonicalize.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "lib/Dialect/mhlo/transforms/canonicalize.td", + td_file = "lib/Dialect/mhlo/IR/mhlo_canonicalize.td", td_relative_includes = [ "include", ], @@ -146,7 +201,7 @@ gentbl( ), ( "-gen-op-interface-defs", - "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc", + "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cpp.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -168,6 +223,7 @@ cc_library( "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h", "include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc", ], + includes = ["include"], deps = [ ":infer_fusibility_op_interface_gen", "@llvm-project//mlir:IR", @@ -180,6 +236,7 @@ cc_library( name = "convert_op_folder", srcs = ["lib/utils/convert_op_folder.cc"], hdrs = ["include/mlir-hlo/utils/convert_op_folder.h"], + includes = ["include"], deps = [ "@llvm-project//mlir:IR", ], @@ -203,13 +260,13 @@ cc_library( ], includes = ["include"], deps = [ + "hlo_ops_pattern_gen", ":canonicalize_inc_gen", ":chlo_ops_inc_gen", ":convert_op_folder", ":hlo_ops_base_inc_gen", ":hlo_ops_inc_gen", ":infer_fusibility_op_interface", - "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -254,7 +311,7 @@ cc_library( ) cc_library( - name = "hlo_dialect_registration", + name = "hlo_dialect_force_registration", srcs = ["lib/Dialect/mhlo/IR/dialect_registration.cc"], deps = [ ":hlo", @@ -264,6 +321,17 @@ cc_library( alwayslink = 1, ) +cc_library( + name = 
"hlo_dialect_registration", + srcs = ["lib/Dialect/mhlo/IR/init.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/IR/register.h"], + deps = [ + ":hlo", + ":lhlo", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "sink_constants_to_control_flow", srcs = ["lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc"], @@ -273,6 +341,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], @@ -307,7 +376,6 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", @@ -322,7 +390,6 @@ cc_library( srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc"], deps = [ ":lhlo", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", @@ -337,6 +404,7 @@ cc_library( cc_library( name = "lhlo_legalize_to_llvm", srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"], + hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ ":lhlo", "@llvm-project//mlir:IR", @@ -357,7 +425,6 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:Affine", "@llvm-project//mlir:IR", @@ -375,7 +442,6 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"], deps = [ ":hlo", - "@com_google_absl//absl/memory", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Shape", @@ -392,7 +458,6 @@ cc_library( ":hlo", ":lhlo", ":map_lmhlo_to_scalar_op", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", @@ -411,7 +476,6 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":lhlo", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgTransforms", @@ -429,7 +493,6 @@ cc_library( hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/passes.h"], deps = [ ":lhlo", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -450,7 +513,6 @@ cc_library( ":hlo", ":lhlo", ":map_hlo_to_lhlo_op", - "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -465,6 +527,7 @@ cc_library( name = "cycle_detector", srcs = ["lib/utils/cycle_detector.cc"], hdrs = ["include/mlir-hlo/utils/cycle_detector.h"], + includes = ["include"], deps = [ "@llvm-project//llvm:Support", ], @@ -501,13 +564,14 @@ cc_library( gentbl( name = "legalize_to_standard_inc_gen", + strip_include_prefix = "lib/Dialect/mhlo/transforms/", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td", td_relative_includes = [ - "../hlo/include", + "include", ], td_srcs = [ ":hlo_ops_td_files", @@ -548,6 +612,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_gather_to_torch_index_select", + srcs = ["lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"], + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", + "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", + ], + deps = [ 
+ ":hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + cc_library( name = "legalize_tanh_to_approximation", srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"], @@ -555,6 +638,7 @@ cc_library( "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", "include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h", ], + includes = ["include"], deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -568,13 +652,14 @@ cc_library( gentbl( name = "lower_complex_inc_gen", + strip_include_prefix = "lib/Dialect/mhlo/transforms/", tbl_outs = [ ("-gen-rewriters", "lib/Dialect/mhlo/transforms/generated_lower_complex.inc"), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lib/Dialect/mhlo/transforms/lower_complex_patterns.td", td_relative_includes = [ - "../hlo/include", + "include", ], td_srcs = [ ":hlo_ops_td_files", @@ -587,9 +672,9 @@ cc_library( #TODO(aminim): find a better name here? name = "mhlo_to_mhlo_lowering_patterns", srcs = [ - "lib/Dialect/mhlo/transforms/generated_lower_complex.inc", "lib/Dialect/mhlo/transforms/lower_complex.cc", "lib/Dialect/mhlo/transforms/lower_general_dot.cc", + "lib/Dialect/mhlo/transforms/optimize_mhlo.cc", ], hdrs = [ "include/mlir-hlo/Dialect/mhlo/transforms/passes.h", @@ -597,7 +682,8 @@ cc_library( ], deps = [ ":hlo", - ":hlo_dialect_registration", + ":hlo_dialect_force_registration", + ":lower_complex_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", @@ -649,7 +735,9 @@ cc_library( deps = [ ":hlo", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], ) @@ -661,6 +749,7 @@ cc_library( "lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc", "lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc", "lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc", + "lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc", "lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc", "lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc", ], @@ -671,13 +760,12 @@ cc_library( ":lhlo_legalize_to_llvm", # build-cleaner: keep ":materialize_broadcasts", # build-cleaner: keep ":unfuse_batch_norm", # build-cleaner: keep - "@llvm-project//mlir:AffineToStandardTransforms", - "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Shape", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", @@ -686,15 +774,20 @@ cc_library( ) cc_library( - name = "all_passes_for_testing", + name = "all_passes", + hdrs = [ + "include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h", + ], visibility = [ - "//tensorflow/compiler/mlir:__subpackages__", + ":friends", ], deps = [ + ":LmhloPassIncGen", + ":MhloPassIncGen", ":chlo_legalize_to_hlo", - ":hlo_dialect_registration", ":hlo_legalize_to_lhlo", ":legalize_control_flow", + ":legalize_gather_to_torch_index_select", ":legalize_tanh_to_approximation", ":legalize_to_linalg", ":legalize_to_standard", @@ -709,15 +802,23 @@ cc_library( ":sink_constants_to_control_flow", ":test_passes", ":transform_unranked_hlo", + "@llvm-project//mlir:Pass", ], ) cc_binary( 
name = "mlir-hlo-opt", + srcs = [ + "tools/mlir-hlo-opt/mlir-hlo-opt.cpp", + ], deps = [ - ":all_passes_for_testing", - "@llvm-project//mlir:AllPassesAndDialects", + ":all_passes", + ":hlo_dialect_registration", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:MlirOptMain", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/hlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/CMakeLists.txt new file mode 100644 index 00000000000..c4e2ea123df --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/CMakeLists.txt @@ -0,0 +1,94 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +cmake_minimum_required(VERSION 3.13.4) + +if(POLICY CMP0068) + cmake_policy(SET CMP0068 NEW) + set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) +endif() + +if(POLICY CMP0075) + cmake_policy(SET CMP0075 NEW) +endif() + +if(POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif() + +#------------------------------------------------------------------------------- +# Project setup and globals +#------------------------------------------------------------------------------- + +project(mlir-hlo LANGUAGES CXX C) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 14) +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") + +#------------------------------------------------------------------------------- +# Options and settings +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# MSVC defaults +#------------------------------------------------------------------------------- + +if(MSVC) + add_compile_options( + $<$:/MD> + $<$:/MD> + $<$:/MD> + ) +endif() + +#------------------------------------------------------------------------------- +# MLIR/LLVM Configuration +#------------------------------------------------------------------------------- + +find_package(MLIR REQUIRED CONFIG) +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +if(LLVM_ENABLE_ZLIB) + find_package(ZLIB) +endif() + +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) +include_directories(${LLVM_INCLUDE_DIRS}) +include_directories(${MLIR_INCLUDE_DIRS}) +include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/include) +include_directories(${PROJECT_BINARY_DIR}/) +link_directories(${LLVM_BUILD_LIBRARY_DIR}) +add_definitions(${LLVM_DEFINITIONS}) + +#------------------------------------------------------------------------------- +# Directory setup +#------------------------------------------------------------------------------- + +set(MLIR_HLO_SOURCE_DIR 
${CMAKE_CURRENT_SOURCE_DIR})
+set(MLIR_HLO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
+
+add_custom_target(check-mlir-hlo)
+
+add_subdirectory(include/mlir-hlo)
+add_subdirectory(lib)
+add_subdirectory(tools)
+add_subdirectory(tests)
diff --git a/tensorflow/compiler/mlir/hlo/README.md b/tensorflow/compiler/mlir/hlo/README.md
new file mode 100644
index 00000000000..9eaa14031fd
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/README.md
@@ -0,0 +1,233 @@
+# MLIR-HLO: A Standalone "HLO" MLIR-based Compiler
+
+The code here exists in two places:
+
+* https://github.com/tensorflow/tensorflow/tree/master/tensorflow/compiler/mlir/hlo;
+  this is the canonical location and where contributions should be made using
+  GitHub pull-requests.
+* https://github.com/tensorflow/mlir-hlo; this is a standalone repository with
+  a view of the same code to allow other projects to use this without
+  depending on the entire TF monorepo.
+
+This implements a self-contained compiler for a set of linear algebra
+operations inspired by the XLA
+[HLO IR](https://www.tensorflow.org/xla/architecture#how_does_xla_work), using
+MLIR components. It is designed to provide an end-to-end flow independent of
+TensorFlow and XLA, but usable inside of these projects.
+
+Coding practice and conventions in this repository follow the
+[MLIR Developer Guide](https://mlir.llvm.org/getting_started/DeveloperGuide/),
+as part of the intent for this repository to act as an incubator for technology
+to upstream.
+
+## QuickStart: building and testing
+
+These instructions work on Linux; you may have to adjust them for your platform.
+
+To build the code in this repository, you need a clone of the LLVM/MLIR git
+repository:
+
+    $ git clone https://github.com/llvm/llvm-project.git
+
+
+You need to make sure you have the right commit checked out in the LLVM
+repository (you need to do this every time you pull from this repo):
+
+    $ (cd llvm-project && git checkout $(cat build_tools/llvm_version.txt))
+
+We provide a script to configure and build LLVM/MLIR:
+
+    $ build_tools/build_mlir.sh ${PWD}/llvm-project/ ${PWD}/llvm-build
+
+Again, this is something to do every time you pull from this repository and the
+LLVM revision changes.
+
+Finally, you can build and test this repository:
+
+    $ mkdir build && cd build
+    $ cmake .. -GNinja \
+        -DLLVM_ENABLE_LLD=ON \
+        -DCMAKE_BUILD_TYPE=Release \
+        -DLLVM_ENABLE_ASSERTIONS=On \
+        -DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir
+    $ ninja check-mlir-hlo
+
+
+## Overview
+
+MLIR-HLO aims to provide an end-to-end compiler for CPU and GPU, as well as
+building reusable blocks for other accelerators. This is heavily inspired by the
+success of XLA.
+
+[XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) is a
+domain-specific compiler framework and execution environment for linear algebra,
+which powers code generation for ML frameworks like TensorFlow, JAX, and others.
+
+A cornerstone of XLA is the HLO (High Level Optimizer) IR, which offers a
+carefully selected, fixed list of operations that are mostly orthogonal to each
+other. It provides an efficient optimizer for computations expressed with this
+set of operations and generates code for hardware platforms such as CPUs, GPUs,
+and TPUs. Its goal is to provide a uniform interface to compile and execute these
+optimized HLO programs independently of the targeted device. It is not a
+front-end ML system like TensorFlow or JAX; rather, it is a backend framework
+that optimizes HLO and lowers it to machine code.
+
+The HLO set of operations is closed and has well-defined semantics.
HLO
+operations operate on immutable Tensors with static shapes (actually bounded
+shapes, to be exact) and explicit broadcasts.
+
+[MLIR](https://mlir.llvm.org/) is a compiler infrastructure which intends to
+come with "batteries included"; as such, it intends to provide all the blocks
+required to assemble graph optimization and codegen pipelines. The longer-term
+roadmap for MLIR is to provide a
+[Tensor Compute Primitive](https://llvm.discourse.group/c/mlir/MLIR-TCP-WG/36)
+(TCP) dialect, which should hopefully be general enough to model what HLO
+represents today (see
+[slides](https://drive.google.com/open?id=1iljcpTQ5NPaMfGpoPDFml1XkYxjK_6A4) and
+[recording](https://drive.google.com/open?id=1jSPa8TwPKUt0WuLquGc8OgSUVYJHMvWZ)
+for a technical discussion on this topic).
+
+The work on MLIR-HLO can be seen as a stepping stone towards building TCP, while
+integrating intermediate components into XLA itself by relying on the
+well-proven HLO IR and introducing more pieces from upstream MLIR
+([Linalg](https://mlir.llvm.org/docs/Dialects/Linalg/),
+[Vector](https://mlir.llvm.org/docs/Dialects/Vector/),
+[GPU](https://mlir.llvm.org/docs/Dialects/GPU/) dialect, ...).
+[This document](https://www.tensorflow.org/mlir/xla_gpu_codegen) provides more
+information on the current migration of the XLA GPU codegen.
+
+## MLIR Dialects for XLA-style compilation
+
+This repository defines three dialects to support an HLO-like compilation
+pipeline using MLIR:
+
+* `chlo`: the "client" HLO dialect, intended to be closer to the frontend
+  (including implicit broadcast semantics).
+* `mhlo`: the "meta"-HLO dialect; similar to `xla_hlo`, but with extensions for
+  dynamic shape support.
+* `lmhlo`: "late"-"meta"-HLO; it is the IR after buffer allocation is
+  performed. In XLA, buffer allocation is a side data structure that keeps
+  track of this information, while this separate dialect materializes it in
+  the IR.
+
+We describe these in more detail below.
+
+### HLO Client Dialect: `chlo`.
+
+* It was originally designed to map the
+  [XLA client APIs](https://www.tensorflow.org/xla/operation_semantics) (e.g.,
+  ops support implicit broadcast and are roughly modeled on the XlaBuilder API)
+  modulo support for dynamic shapes and additional ops required to support
+  dynamic client-side HLOs.
+* Ops can come from either the XlaBuilder, or XLA helper functions can be
+  converted into ops (given the ambiguity in what constitutes these ops, there
+  is some freedom to decide). The goal of this dialect is to correspond closely
+  to the client level and to provide a thin layer between client use and op
+  construction (making it cheap to construct, and keeping optimizations on the
+  dialect close to optimizations on the client ops).
+
+Entry:
+
+* The vast majority of old "client" interactions are via the XlaBuilder APIs.
+  These APIs are used by TF2XLA kernels, JAX, the PyTorch bridge, and directly.
+  The legalization path (described below) can also reuse the XlaBuilder's APIs
+  to construct XLA Client HLO ops directly (this uses MlirXlaBuilder, which is a
+  subclass of XlaBuilder).
+* The other entry point is during legalization from TensorFlow ops in the TF
+  Graph Compiler and other tools (e.g., SavedModel lowering and TFCompile).
+
+Exit:
+
+* MHLO.
+* May be exported to xla::HloInstructionProto by invoking the XlaBuilder APIs
+  (with the regular XlaBuilder).
+
+The `chlo` dialect started originally as a mapping to the XLA client Builder
+APIs. It can be both constructed from and converted back to the existing XLA
+interfaces using the XlaBuilder API.
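+
+For illustration only (a schematic example in the generic MLIR syntax; the
+operand names and shapes here are invented, not taken from this change), a
+`chlo` binary op can accept operands of different shapes and leave the
+broadcast implicit:
+
+    "chlo.broadcast_add"(%scalar, %vec) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
+
+The legalization out of `chlo` is then expected to materialize such broadcasts
+explicitly in the lower-level dialects.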
+
+Due to the way that translation into and out of the dialect works, there is no
+expectation that this dialect roundtrips to XLA (i.e., it is only intended to be
+translated to MLIR and then legalized to another dialect or translated to
+HloInstructionProto).
+
+The export approach of reusing the XlaBuilders enables reusing a lot of logic
+that was already implemented in terms of computing shapes, inserting broadcasts,
+etc.
+
+An important topic here is that the XLA Client HLO ops are not a well-defined
+set; in particular, what some would consider helper functions, others would
+consider ops. It should be easy to move between the two, so one can either
+define a new op along with the helper function or autogenerate the helper
+functions from the descriptions of the ops. For the former, a simple approach
+would be to consider the context in which the op is being constructed and, if it
+is an MLIR one, construct an op in the client dialect instead of making further
+calls into XlaBuilder. The latter could be implemented by adding the op and a
+legalization of the op to other known ops, from which a helper function can be
+generated and used as usual.
+
+Status: Exists but needs to be cleaned up.
+
+### Meta HLO Dialect `mhlo`
+
+* The dialect is closer to the current HLO server ops (e.g., no implicit
+  broadcast).
+* The MHLO dialect is where we can deviate from the requirements of the client
+  or server dialect, in particular:
+  * Control flow ops with implicit capture to enable simpler optimizations
+    (e.g., generic LICM, unroll & jam, etc.)
+  * Multiple-result ops (e.g., no tuples)
+  * More ops (for example, unique op or assert op), and ops that don't need
+    to be added to either client or server dialect.
+  * An op set not constrained by implementation (e.g., hlo.add operating on,
+    say, i79 or !mydialect.weird_type is allowed even though no XLA backend
+    supports it). Verification on types happens at the boundaries.
+  * It does not need to preserve some deprecated XLA constructs (e.g.
+    stateful RNG HLO).
+  * More dynamic shape support ops without the need to update all
+    users/backends.
+* This dialect enables evolving HLO independently from XLA in order to
+  experiment with features we'd like to upstream in MLIR TCP. In particular, it
+  intends to be user-extensible through
+  [interfaces](https://mlir.llvm.org/docs/Interfaces/).
+* It should have no TensorFlow, proto, or other Google-internal dependencies.
+* It need not be a complete superset of ops compared to the XLA HLO dialect.
+
+Entry:
+
+* Legalization from the `chlo` dialect or conversion from XLA HLO.
+* Directly emitted from the TF Graph Compiler.
+* Builder call (e.g., EDSL).
+
+Exit:
+
+* LMHLO, Linalg, IREE, or directly used in codegen.
+* XLA HLO.
+
+The MHLO dialect has no direct export format; it is only meant as an
+intermediate optimization dialect/format. It is also where we can experiment
+cheaply with new ops. This format will be where the representation starts to
+differ from existing endpoints.
+
+Status: Exists but needs to be cleaned up and evolved, in particular with
+respect to supporting dynamic shapes.
+
+### LMHLO
+
+LMHLO corresponds to late `mhlo` and operates on the buffer domain (e.g.,
+memref) with side-effecting operations. The lowering from the `mhlo` dialect
+proceeds by way of scheduling, memory, and buffer allocation. The current
+mapping is directly on XLA Client HLOs but without implicit broadcast and with
+operations on memrefs. This dialect will instead be rebased on the `mhlo`
+dialect but will still operate on buffers.
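+
+A minimal sketch of the difference (again in generic MLIR syntax, with invented
+operand names): an `mhlo` op produces a new tensor value, while the
+corresponding `lmhlo` op takes an extra memref operand and writes its result
+into that caller-provided buffer:
+
+    %sum = "mhlo.add"(%lhs, %rhs) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+    "lmhlo.add"(%lhs_buf, %rhs_buf, %sum_buf) : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()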
+ +Entry: + +* Post buffer assignment on `mhlo` dialect, or from XLA after buffer + assignment. + +Exit: + +* Codegen (LLVM IR in the common cases at the moment) + +## End-to-End pipeline + +TODO diff --git a/tensorflow/compiler/mlir/hlo/build_tools/build_mlir.sh b/tensorflow/compiler/mlir/hlo/build_tools/build_mlir.sh new file mode 100755 index 00000000000..5ccefb9416f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/build_tools/build_mlir.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +if [[ $# -ne 2 ]] ; then + echo "Usage: $0 " + exit 1 +fi + +# LLVM source +LLVM_SRC_DIR="$1" +build_dir="$2" + +if ! [ -f "$LLVM_SRC_DIR/llvm/CMakeLists.txt" ]; then + echo "Expected the path to LLVM to be set correctly (got '$LLVM_SRC_DIR'): can't find CMakeLists.txt" + exit 1 +fi +echo "Using LLVM source dir: $LLVM_SRC_DIR" + +# Setup directories. +echo "Building MLIR in $build_dir" +mkdir -p "$build_dir" + +echo "Beginning build (commands will echo)" +set -x + +cmake -GNinja \ + "-H$LLVM_SRC_DIR/llvm" \ + "-B$build_dir" \ + -DLLVM_INSTALL_UTILS=ON \ + -DLLVM_ENABLE_LLD=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_TARGETS_TO_BUILD="X86;NVPTX;AMDGPU" \ + -DLLVM_INCLUDE_TOOLS=ON \ + -DLLVM_BUILD_TOOLS=OFF \ + -DLLVM_INCLUDE_TESTS=OFF \ + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DLLVM_ENABLE_ASSERTIONS=On + +cmake --build "$build_dir" --target all --target mlir-cpu-runner diff --git a/tensorflow/compiler/mlir/hlo/build_tools/llvm_version.txt b/tensorflow/compiler/mlir/hlo/build_tools/llvm_version.txt new file mode 100644 index 00000000000..0d5446142ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/build_tools/llvm_version.txt @@ -0,0 +1,2 @@ + + diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt new file mode 100644 index 00000000000..92759d76383 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +add_subdirectory(Dialect) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt new file mode 100644 index 00000000000..5ee1a1924ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(mhlo) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt new file mode 100644 index 00000000000..e138afa587f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(IR) +add_subdirectory(transforms) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt new file mode 100644 index 00000000000..09bdca84cd3 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/CMakeLists.txt @@ -0,0 +1,31 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR +function(add_mlir_hlo_dialect dialect dialect_namespace) + set(LLVM_TARGET_DEFINITIONS ${dialect}.td) + mlir_tablegen(${dialect}.h.inc -gen-op-decls) + mlir_tablegen(${dialect}.cc.inc -gen-op-defs) + mlir_tablegen(${dialect}_structs.h.inc -gen-struct-attr-decls) + mlir_tablegen(${dialect}_structs.cc.inc -gen-struct-attr-defs) + add_public_tablegen_target(MLIR${dialect}IncGen) + add_dependencies(mlir-headers MLIR${dialect}IncGen) +endfunction() + +add_mlir_hlo_dialect(chlo_ops chlo) +add_mlir_hlo_dialect(hlo_ops mhlo) +add_mlir_hlo_dialect(lhlo_ops lmhlo) + +add_mlir_interface(infer_fusibility_op_interface) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 1fbf55ded83..9704f34a4d6 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -17,27 +17,48 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/DialectImplementation.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { namespace chlo { class HloClientDialect : public Dialect { + void initialize(); + public: - explicit HloClientDialect(MLIRContext *context); + explicit HloClientDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, + TypeID::get()) { + initialize(); + } static StringRef getDialectNamespace() { return "chlo"; } }; #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" + +template +static Value getConstantLike(OpBuilder& b, T constant, Value val) { + Type ty = getElementTypeOrSelf(val.getType()); + + auto getAttr = [&]() -> Attribute { + if (ty.isa()) return b.getIntegerAttr(ty, constant); + if (ty.isa()) return b.getFloatAttr(ty, constant); + llvm_unreachable("unhandled element type"); + }; + // TODO(jpienaar): Add ability to pass loc via native call and update. 
+ return b.create(b.getUnknownLoc(), getAttr(), val); +} } // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 79d6fb25318..2f3bbefb5ab 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -33,6 +33,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLOClient_Dialect : Dialect { let name = "chlo"; @@ -338,6 +339,49 @@ def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< let results = (outs HLO_ComplexTensor); } +//===----------------------------------------------------------------------===// +// Unary op +//===----------------------------------------------------------------------===// + +class HLOClient_UnaryElementwiseOp traits, + Type TensorType>: HLOClient_Op { + let arguments = (ins TensorType:$operand); + let results = (outs TensorType); +} + +def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Acos operator"; + + let description = [{ + Returns `Acos(operand)` element-wise. + + $$ + \acos(x) = 2 * \atan(\sqrt(1 - x^2) / (1 + x)) if x != -1 + = pi if x == -1 + $$ + }]; +} + +def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like", + [NoSideEffect, SameOperandsAndResultShape, + InferTypeOpInterface, + DeclareOpInterfaceMethods, + NativeOpTrait<"InferTensorType">]> { + let summary = "Constant like operator"; + + let description = [{ + Returns a splat constant of the same shape as the operand. + }]; + + // TODO(jpienaar): value's type could be tightened. + let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand); + let results = (outs HLO_Tensor); + + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // Broadcasting compare op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index 4de52639bca..0036cc0dc19 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -19,23 +19,23 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/DialectImplementation.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { class OpBuilder; -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.h.inc" namespace mhlo { @@ -91,7 +91,7 @@ LogicalResult deriveShapeFromFirstOperand( SmallVectorImpl *reifiedReturnShapes); #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 0ed4235e23f..d0abbe043ea 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -40,6 +40,14 @@ class HLO_Op traits> : let verifier = [{ return Verify(*this); }]; } +def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">; +def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">; +def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">; +def HLO_CUSTOM_FUSION : StrEnumAttrCase<"kCustom">; +def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [ + HLO_LOOP_FUSION, HLO_INPUT_FUSION, HLO_OUTPUT_FUSION, HLO_CUSTOM_FUSION +]>; + //===----------------------------------------------------------------------===// // MHLO nullary op definitions. 
//===----------------------------------------------------------------------===// @@ -52,15 +60,14 @@ def HLO_ConstOp : HLO_Op<"constant", ); let results = (outs - HLO_Tensor:$output + HLO_StaticShapeTensor:$output ); let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Attribute value" >]; - let printer = [{ return Print(*this, &p); }]; - let parser = [{ return ParseConstOp(&parser, &result); }]; + let assemblyFormat = "attr-dict $value"; let hasFolder = 1; @@ -656,13 +663,14 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { - let arguments = (ins Variadic:$val); + let arguments = (ins Variadic:$val); let results = (outs HLO_Tuple); let builders = [OpBuilder< "OpBuilder &builder, OperationState &results, " "ValueRange values">]; + let hasCanonicalizer = 1; } def HLO_CompareOp: HLO_Op<"compare", @@ -1067,7 +1075,10 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, HLO_Tensor:$operand, I32Attr:$dimension ); - let results = (outs HLO_IntTensor); + // TODO(hinsu): Allow 64-bit result types once XLA HLO dialect based on the + // XLA semantics is available. This limitation is because of the current XLA + // implementation. + let results = (outs I32Tensor); } def HLO_MapOp: HLO_Op<"map", @@ -1318,13 +1329,14 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { } //===----------------------------------------------------------------------===// -// MHLO RngUniform Operator. +// MHLO RNG Operators. //===----------------------------------------------------------------------===// + def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let arguments = (ins HLO_PredIntOrFpTensor:$a, HLO_PredIntOrFpTensor:$b, - I64Tensor:$shape + HLO_DimensionTensor:$shape ); let results = (outs HLO_PredIntOrFpTensor); @@ -1336,7 +1348,7 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let arguments = (ins HLO_FpTensor:$mu, HLO_FpTensor:$sigma, - I64Tensor:$shape + HLO_DimensionTensor:$shape ); let results = (outs HLO_FpTensor); @@ -1344,6 +1356,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let hasCustomHLOConverter = 1; } +def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp { + let arguments = (ins + // TODO(jpienaar): This could be an enum instead. + I32Attr:$rng_algorithm, + HLO_IntOrFpTensor:$initial_state + ); + + let results = (outs HLO_TensorOrTuple:$result); + + // TODO(jpienaar): This should not be needed. + let hasCustomHLOConverter = 1; +} + //===----------------------------------------------------------------------===// // MHLO Quantize Operator. 
//===----------------------------------------------------------------------===// @@ -1375,7 +1400,8 @@ def HLO_FusionOp : HLO_Op<"fusion", []> { let regions = (region SizedRegion<1>:$fused_computation); let arguments = (ins - Variadic:$operands + Variadic:$operands, + OptionalAttr:$fusion_kind ); let results = (outs diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 7f9784d7f11..2f80545ad19 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -316,6 +316,19 @@ class BASE_HLO_RealOp { }]; } +class BASE_HLO_RngBitGeneratorOp { + string summary = "Uniform random number generator operator"; + + string description = [{ + Returns an output with a given shape filled with uniform random bits using + the specified algorithm (or backend default) and returns an updated state + (with the same shape as initial state) and the generated random data. + + See + https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. + }]; +} + class BASE_HLO_RoundOp { string summary = "Round operator"; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index e1ae9e1fb89..c201aeff8ec 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; class ConstantSplat : NativeCodeCall< "hlo::getSplat(&$_builder, $0, " # value # ")">; +class HLO_ConstantLike : NativeCodeCall< + "chlo::getConstantLike($_builder, " # value # ", $0)">; + def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; def BinBroadcastDimensions : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h index ecbf2e05000..00de1170f8a 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h @@ -21,7 +21,7 @@ limitations under the License. namespace mlir { -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td index eb2c1ba3ffe..f8e02d413e9 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -140,7 +140,7 @@ def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { Here the effective workload shape roughly represents the maximum parallelism can be used during the codegen stage. It's used to check the shape-compatibility of the operation. During fusion, we only - try to fuse shape-compatible ops for performace. + try to fuse shape-compatible ops for performance. 
For example, the effective workload shape of an elementwise op is its output shape, while the effective workload shape of a reduction op may be its operand shape. diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index fd31bec44c0..bb9b29096f3 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -19,21 +19,21 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project -#include "mlir/Interfaces/ViewLikeInterface.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" namespace mlir { class OpBuilder; -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.h.inc" namespace lmhlo { @@ -44,7 +44,7 @@ class LmhloDialect : public Dialect { }; #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" } // namespace lmhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 87082219db7..3fa46584ca2 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -66,6 +66,8 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; +def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>; + //===----------------------------------------------------------------------===// // LMHLO nullary op definitions. //===----------------------------------------------------------------------===// @@ -467,7 +469,7 @@ def ReshapeMemRefCastOp: Op:$shape + LHLO_ExtentBuffer:$shape ); let results = (outs AnyRankedOrUnrankedMemRef:$result); diff --git a/tensorflow/core/lib/bfloat16/bfloat16.cc b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h similarity index 61% rename from tensorflow/core/lib/bfloat16/bfloat16.cc rename to tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h index e6e24bc0786..5773901ad78 100644 --- a/tensorflow/core/lib/bfloat16/bfloat16.cc +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#ifndef MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ +#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ -#include "third_party/eigen3/Eigen/Core" +namespace mlir { +namespace mhlo { -namespace tensorflow { +void registerAllDialects(); -const uint16_t bfloat16::NAN_VALUE; -const uint16_t bfloat16::ZERO_VALUE; - -B16_DEVICE_FUNC bfloat16::operator Eigen::half() const { - return static_cast(float(*this)); } -} // end namespace tensorflow +} // namespace mlir + +#endif // MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt new file mode 100644 index 00000000000..6de6851b8d7 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/CMakeLists.txt @@ -0,0 +1,23 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set(LLVM_TARGET_DEFINITIONS mhlo_passes.td) +mlir_tablegen(mhlo_passes.h.inc -gen-pass-decls -name MHLO) +add_public_tablegen_target(MLIRMhloPassIncGen) + +set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td) +mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name LMHLO) +add_public_tablegen_target(MLIRLmhloPassIncGen) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td new file mode 100644 index 00000000000..963ff5dbacf --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.td @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def LhloCopyRemovalPass : Pass<"lhlo-copy-removal", "FuncOp"> { + let summary = "Removes redundant LHLO copy operations."; + let constructor = "createLhloCopyRemovalPass()"; +} + + +def LhloLegalizeToLinalgPass : Pass<"lhlo-legalize-to-linalg", "FuncOp"> { + let summary = "Legalize from LHLO dialect to Linalg dialect."; + let constructor = "createLegalizeLhloToLinalgPass()"; +} + + +def LhloFuseLinalgPass : Pass<"lhlo-fuse-linalg", "FuncOp"> { + let summary = "Greedily fuse linalg ops obtained after LHLO lowering."; + let constructor = "createLhloFuseLinalgPass()"; + let options = [ + Option<"use_parallel_loops_", "use-parallel-loops", "bool", + /*default=*/"false", "Tiles GenericOp consumer to parallel loops before linalg fusion">, + ListOption<"tile_sizes_", "tile-sizes", "unsigned", + "Tile sizes by which to tile linalg operations before fusion", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ]; +} + + +def LhloLegalizeToAffinePass : Pass<"lhlo-legalize-to-affine", "FuncOp"> { + let summary = "Legalize from LHLO dialect to affine dialect."; + let constructor = "createLhloLegalizeToAffinePass()"; +} + + +def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> { + let summary = "Legalize from LHLO dialect to GPU dialect."; + let constructor = "createLegalizeToGpuPass()"; +} + + +def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> { + let summary = "Legalize from LHLO dialect to LLVM."; + let constructor = "createTestLhloToLLVMPass()"; +} + + +def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> { + let summary = "Legalize from LHLO dialect to parallel loops."; + let constructor = "createLegalizeLhloToParallelLoopsPass()"; +} + diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index a0246f93180..c51bcfcfe89 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 5d2bffcec2a..2bb5ab2888d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -18,10 +18,10 @@ limitations under the License.
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" namespace mlir { namespace lmhlo { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td new file mode 100644 index 00000000000..fa3bde24df1 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def TestChloLegalizeToHloPass : Pass<"mhlo-test-chlo-legalize-to-hlo", "FuncOp"> { + let summary = "Test pass for applying chlo -> hlo legalization patterns."; + let constructor = "createTestChloLegalizeToHloPass()"; +} + +def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> { + let summary = "Legalize from HLO dialect to LHLO dialect."; + let constructor = "createLegalizeToLhloPass()"; +} + +def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> { + let summary = "Legalize from MHLO control flow to CFG control flow."; + let constructor = "createLegalizeControlFlowPass()"; +} + +def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> { + let summary = "Legalizes gathers to a torch index select."; + let constructor = "createLegalizeGatherToTorchIndexSelectPass()"; +} + + +def LegalizeTanhToApproximationPass : Pass<"mhlo-legalize-tanh-to-approximation", "FuncOp"> { + let summary = "Legalize tanh from standard dialect to an approximation."; + let constructor = "createLegalizeTanhToApproximationPass()"; +} + + +def HloLegalizeToLinalgPass : Pass<"hlo-legalize-to-linalg", "FuncOp"> { + let summary = "Legalize from HLO dialect to Linalg dialect."; + let constructor = "createLegalizeHloToLinalgPass()"; +} + + +def LegalizeToStandardPass : Pass<"mhlo-legalize-to-std", "FuncOp"> { + let summary = "Legalize from MHLO dialect to standard dialect."; + let constructor = "createLegalizeToStdPass()"; +} + +def LowerComplexPass : Pass<"mhlo-test-lower-complex", "FuncOp"> { + let summary = "Lower complex operations into non-complex operations."; + let constructor = "createLowerComplexPass()"; +} + + +def LegalizeGeneralDotPass : Pass<"mhlo-test-lower-general-dot", "FuncOp"> { + let summary = "Tests lowering general dot to a 
non-batched dot when possible."; + let constructor = "createLegalizeGeneralDotPass()"; +} + + +def TestMaterializeBroadcastsPass : Pass<"mhlo-test-materialize-broadcasts", "FuncOp"> { + let summary = "Test pass for materializing 'broadcast_dimensions' attributes."; + let constructor = "createTestMaterializeBroadcastsPass()"; +} + + +def MhloFusionPass : Pass<"mhlo-fusion", "FuncOp"> { + let summary = "Fuse mhlo ops to kLoop/kInput fusion patterns."; + let constructor = "createMhloFusionPass()"; +} + + +def OptimizeMhloPass : Pass<"mhlo-test-optimize", "FuncOp"> { + let summary = "Run optional HLO optimizations."; + let constructor = "createOptimizeMhloPass()"; +} + + +def SinkConstantsToControlFlowPass : Pass<"mhlo-sink-constants-to-control-flow", "FuncOp"> { + let summary = "Sink constants implicitly captured in control flow regions. This " + "is necessary to export to XLA."; + let constructor = "createSinkConstantsToControlFlowPass()"; +} + + +def TestInferShapedTypeMethodsPass : Pass<"mhlo-test-infer-shaped-type-methods", "FuncOp"> { + let summary = "Uses test ops to invoke InferShapedTypeOpInterface methods."; + let constructor = "createTestInferShapedTypeMethodsPass()"; +} + + +def TransformUnrankedHloPass : Pass<"transform-unranked-hlo", "FuncOp"> { + let summary = "Realize element-wise operations on ranked tensors where possible."; + let constructor = "createTransformUnrankedHloPass()"; +} + + +def TestUnfuseBatchNormPass : Pass<"mhlo-test-unfuse-batch-norm", "FuncOp"> { + let summary = "Test pass for unfusing batch norm into primitive operations."; + let constructor = "createTestUnfuseBatchNormPass()"; +} diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index 9ea39e95fef..efa116f3f0d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -23,6 +23,7 @@ limitations under the License. namespace mlir { class FuncOp; +class FunctionPass; class ModuleOp; class Operation; template @@ -58,18 +59,26 @@ std::unique_ptr> createSinkConstantsToControlFlowPass(); // fuse mhlo ops to kLoop/kInput fusion patterns std::unique_ptr> createMhloFusionPass(); +/// Lowers the standard TanhOp to an approximation that does not use intrinsics. +std::unique_ptr> createLegalizeTanhToApproximationPass(); + +std::unique_ptr createOptimizeMhloPass(); +std::unique_ptr createLowerComplexPass(); +std::unique_ptr<::mlir::Pass> createLegalizeGeneralDotPass(); +std::unique_ptr createLegalizeGatherToTorchIndexSelectPass(); + } // namespace mhlo namespace lmhlo { // Lowers from LHLO dialect to Affine dialect. -std::unique_ptr> createLegalizeToAffinePass(); +std::unique_ptr> createLhloLegalizeToAffinePass(); // Lowers from LHLO dialect to Linalg dialect. std::unique_ptr> createLegalizeLhloToLinalgPass(); // Lowers from LHLO dialect to GPU dialect. -std::unique_ptr> createLegalizeToGpuPass(); +std::unique_ptr createLegalizeToGpuPass(); // Fuses linalg ops obtained after LHLO lowering. To enable fusion, // operations are first tiled. @@ -80,7 +89,7 @@ std::unique_ptr> createLegalizeToGpuPass(); // 'tile_sizes' provides the tile sizes to use for tiling. If the linalg // operation has more dimensions than tile sizes provided, 1 is used as // default.
-std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops = false, llvm::ArrayRef tile_sizes = {}); // Removes unnecessary LHLO copies which copy from the allocated buffers to the @@ -94,12 +103,6 @@ std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); } // namespace lmhlo -namespace hlo { - -/// Lowers the standard TanhOp to an approximation that does not use intrinsics. -std::unique_ptr> createLegalizeTanhToApproximationPass(); - -} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h new file mode 100644 index 00000000000..8f70f64359b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ +#define MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ + +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace mhlo { + +std::unique_ptr createTestChloLegalizeToHloPass(); +std::unique_ptr createTestInferShapedTypeMethodsPass(); +std::unique_ptr createTestMaterializeBroadcastsPass(); +std::unique_ptr createTestUnfuseBatchNormPass(); + +#define GEN_PASS_REGISTRATION +#include "mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.h.inc" + +inline void registerAllMhloPasses() { registerMHLOPasses(); } + +} // namespace mhlo + +namespace lmhlo { + +std::unique_ptr createTestLhloToLLVMPass(); + +#define GEN_PASS_REGISTRATION +#include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc" + +inline void registerAllLmhloPasses() { registerLMHLOPasses(); } + +} // namespace lmhlo +} // namespace mlir + +#endif // MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REGISTER_PASSES_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index cb9a85a658a..725155e9403 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -18,9 +18,9 @@ limitations under the License. 
#include -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { class LLVMTypeConverter; @@ -38,6 +38,13 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, void PopulateComplexLoweringPatterns(MLIRContext *context, OwningRewritePatternList *patterns); +void PopulateOptimizeMHLOPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +// Rewrite patterns for gather to equivalent torch index select legalization. +void PopulateGatherToTorchIndexSelectPatterns( + mlir::MLIRContext *context, OwningRewritePatternList *patterns); + void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, MLIRContext *ctx); @@ -73,13 +80,17 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, void PopulateUnfuseBatchNormPatterns(MLIRContext *context, OwningRewritePatternList *patterns); +// Populates a pattern that translates the standard TanhOp to an approximation +// that does not use intrinsics. +void PopulateTanhToApproximationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + } // namespace mhlo namespace lmhlo { /// Collect a set of patterns to convert from the LHLO dialect to LLVM. -void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, - LLVMTypeConverter *converter, +void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, OwningRewritePatternList *patterns); } // namespace lmhlo @@ -93,14 +104,6 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, } // namespace chlo -namespace hlo { - -// Populates a pattern that translates the standard TanhOp to an approximation -// that does not use intrinsics. -void PopulateTanhToApproximationPatterns(MLIRContext *context, - OwningRewritePatternList *patterns); - -} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h index 3be7d42cc25..1c57073f4ab 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h @@ -19,12 +19,12 @@ limitations under the License. // Utilities relating to implementing HLO broadcasting. // Note: This file should not depend on any non-MLIR TensorFlow libraries. 
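An aside, not part of the patch: the Populate*Patterns entry points declared in rewriters.h above are meant to be wired into passes by their callers. Below is a minimal sketch of that wiring against MLIR APIs of roughly this vintage; the pass name ExampleTanhApproximationPass is hypothetical, and the exact header that declares applyPatternsAndFoldGreedily varies by MLIR revision.

#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/PatternMatch.h"  // declares applyPatternsAndFoldGreedily in this era (later MLIR moves it to Transforms/GreedyPatternRewriteDriver.h)
#include "mlir/Pass/Pass.h"

namespace {
// Hypothetical function pass that applies the tanh-approximation patterns
// exposed by rewriters.h to every function body.
struct ExampleTanhApproximationPass
    : public mlir::PassWrapper<ExampleTanhApproximationPass,
                               mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::OwningRewritePatternList patterns;
    // Collect the patterns exported by this patch ...
    mlir::mhlo::PopulateTanhToApproximationPatterns(&getContext(), &patterns);
    // ... and apply them greedily to the current function.
    mlir::applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};
}  // namespace

The in-tree pass constructors declared in passes.h are expected to do essentially this wiring, so downstream users can reuse either the ready-made passes or just the pattern lists. The diff resumes below.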
-#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { @@ -38,10 +38,12 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, // Emits shape dialect ops to compute the result shape for a broadcasting // binary elementwise op which broadcasts according to "numpy" semantics -// (see above), returning an extents tensor of the resulting shape. -Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, - Value rhs, - OpBuilder& builder); +// (see above), returning a `shape.shape` or an extent tensor of the resulting +// shape. The result should only be an extent tensor in contexts that ensure +// both operands to be broadcastable. +Value ComputeBinaryElementwiseBroadcastingResultExtents( + Location loc, Value lhs, Value rhs, OpBuilder& builder, + bool unsafe_as_extent_tensor); } // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h index a63df336d8f..4cf74385843 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CONVERT_OP_FOLDER_H_ -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { namespace hlo { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h index b31ba231acd..1e335ae6b82 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_HLO_UTILS_H_ -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace hlo { diff --git a/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt new file mode 100644 index 00000000000..ec65a5ee882 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(Dialect) +add_subdirectory(utils) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt new file mode 100644 index 00000000000..5ee1a1924ec --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/CMakeLists.txt @@ -0,0 +1,16 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(mhlo) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt new file mode 100644 index 00000000000..e138afa587f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/CMakeLists.txt @@ -0,0 +1,17 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +add_subdirectory(IR) +add_subdirectory(transforms) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt new file mode 100644 index 00000000000..d7bb5057b00 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/CMakeLists.txt @@ -0,0 +1,82 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(LLVM_TARGET_DEFINITIONS hlo_patterns.td) +mlir_tablegen(hlo_patterns.cc.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloRewriterIncGen) + +set(LLVM_TARGET_DEFINITIONS mhlo_canonicalize.td) +mlir_tablegen(mhlo_canonicalize.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloCanonicalizeIncGen) + +add_mlir_dialect_library(ChloDialect + chlo_ops.cc + + DEPENDS + MLIRchlo_opsIncGen +) +target_link_libraries(ChloDialect PUBLIC MLIRIR) + +add_mlir_library(MhloInferFusibilityOpInterface + infer_fusibility_op_interface.cc + + DEPENDS + MLIRinfer_fusibility_op_interfaceIncGen +) + + +add_mlir_dialect_library(MhloDialect + hlo_ops.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloCanonicalizeIncGen + MLIRMhloRewriterIncGen + MLIRinfer_fusibility_op_interfaceIncGen +) +target_link_libraries(MhloDialect + PUBLIC + MLIRIR + MhloInferFusibilityOpInterface + MLIRMhloUtils +) + + +add_mlir_dialect_library(LmhloDialect + lhlo_ops.cc + + DEPENDS + MLIRlhlo_opsIncGen +) +target_link_libraries(LmhloDialect PUBLIC MLIRIR) + + +add_mlir_dialect_library(MhloRegisterDialects + init.cc +DEPENDS + MLIRchlo_opsIncGen + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen +) +target_link_libraries(MhloRegisterDialects + PUBLIC + ChloDialect + MhloDialect + LmhloDialect +) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/canonicalize.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td similarity index 100% rename from tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/canonicalize.td rename to tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_canonicalize.td diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index c6c193a9d89..b5eacd686bd 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -13,14 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/utils/broadcast_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace chlo { @@ -151,7 +153,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( } Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents( - loc, lhs, rhs, builder); + loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false); if (!computed_shape) return failure(); reifiedReturnShapes.push_back(computed_shape); return success(); @@ -259,18 +261,59 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); #undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS #undef BROADCAST_BINARY_OP_DEFS +static LogicalResult Verify(ConstantLikeOp op) { + if (op.value().getType() != op.getType().cast().getElementType()) + return op.emitOpError() << "value's type doesn't match element return type"; + return success(); +} + +LogicalResult ConstantLikeOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + ConstantLikeOp::Adaptor op(operands, attributes); + if (failed(op.verify(location.getValue()))) return failure(); + Type element_type = op.value().getType(); + Type operand_type = op.operand().getType(); + if (operand_type.isa()) { + inferedReturnShapes.emplace_back(element_type); + } else { + const auto& shape = operand_type.cast().getShape(); + inferedReturnShapes.emplace_back(shape, element_type); + } + return success(); +} + +struct ConstantLikeToConstant : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstantLikeOp op, + PatternRewriter& rewriter) const override { + auto op_type = op.operand().getType().cast(); + if (!op_type.hasStaticShape()) return failure(); + auto type = RankedTensorType::get(op_type.getShape(), op.value().getType()); + ElementsAttr attr = DenseElementsAttr::get(type, op.value()); + rewriter.replaceOpWithNewOp(op.getOperation(), attr); + return success(); + } +}; + +void ConstantLikeOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" //===----------------------------------------------------------------------===// // chlo Dialect Constructor //===----------------------------------------------------------------------===// -HloClientDialect::HloClientDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { +void HloClientDialect::initialize() { addOperations< #define GET_OP_LIST -#include 
"tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" >(); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc index f4df946d11a..9d1c354690a 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" // Static initialization for *HLO dialects registration. static mlir::DialectRegistration mhlo_ops; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index cbd478a0283..f5deb94e3a4 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines the operations used in the MHLO dialect. -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include #include @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_set.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -35,31 +34,33 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/OpImplementation.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/InliningUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "mlir-hlo/utils/convert_op_folder.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/InliningUtils.h" namespace mlir { -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" +#include "hlo_patterns.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_structs.cc.inc" namespace mhlo { Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, @@ -104,44 +105,13 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, return GetI64ElementsAttr(slice_limits, builder); } -#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_canonicalize.inc" +#include "mhlo_canonicalize.inc" } // namespace //===----------------------------------------------------------------------===// // ConstOp //===----------------------------------------------------------------------===// -static void Print(ConstOp op, OpAsmPrinter* printer) { - // Print op name. - *printer << op.getOperationName(); - - // Elide attribute value while printing the attribute dictionary. 
- SmallVector elided_attrs; - elided_attrs.push_back("value"); - printer->printOptionalAttrDict(op.getAttrs(), elided_attrs); - - *printer << ' ' << op.value(); -} - -static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) { - if (parser->parseOptionalAttrDict(result->attributes)) return failure(); - - // If colon is not present after attribute dictionary, it should be short form - // and attribute 'value' is outside the dictionary. - if (failed(parser->parseOptionalColon())) { - Attribute value; - if (parser->parseAttribute(value, "value", result->attributes)) - return failure(); - return parser->addTypeToList(value.getType(), result->types); - } - - // Long form should have type of the result after colon. - Type ty; - if (parser->parseType(ty)) return failure(); - result->types.push_back(ty); - return success(); -} - OpFoldResult ConstOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); @@ -339,6 +309,33 @@ void DynamicIotaOp::getCanonicalizationPatterns( results.insert(context); } +//===----------------------------------------------------------------------===// +// DynamicUpdateSliceOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicUpdateSliceOp op) { + OperandRange indices = op.start_indices(); + if (indices.size() <= 1) return success(); + + // Note: start_indices is constrained to Variadic, so it + // is OK to cast indices to ShapedType here. + auto idx_tensor = indices.take_front().front().getType().cast(); + Type first_elem_ty = idx_tensor.getElementType(); + Type elem_ty; + + for (auto idx : llvm::drop_begin(indices, 1)) { + idx_tensor = idx.getType().cast(); + elem_ty = idx_tensor.getElementType(); + + if (first_elem_ty != elem_ty) { + return op.emitOpError() << "start indices must have same element type " + "(encountered mismatch: " + << first_elem_ty << " vs " << elem_ty << ")"; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // AbsOp //===----------------------------------------------------------------------===// @@ -373,8 +370,8 @@ static LogicalResult Verify(CollectivePermuteOp op) { << "expect source_target_pairs attribute of shape (N, 2), but got (" << type.getShape() << ")"; // Check source target pairs for duplicate sources or targets - absl::flat_hash_set sources; - absl::flat_hash_set targets; + llvm::DenseSet sources; + llvm::DenseSet targets; for (auto i = op.source_target_pairs().begin(), e = op.source_target_pairs().end(); i != e; ++i) { @@ -505,6 +502,46 @@ static LogicalResult Verify(TupleOp op) { return success(); } +namespace { + +// Pattern for unpacking and repacking the same tuple. 
+struct UnpackRepackSameTuple : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TupleOp op, + PatternRewriter& rewriter) const override { + if (op.val().empty()) return failure(); + + Value first_element = op.val().front(); + auto first_element_op = + dyn_cast_or_null(first_element.getDefiningOp()); + if (!first_element_op || first_element_op.indexAttr().getInt() != 0) + return failure(); + + Value tuple_predecessor = first_element_op.getOperand(); + if (tuple_predecessor.getType() != op.getType()) return failure(); + + for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) { + auto element_op = dyn_cast_or_null( + element_and_idx.value().getDefiningOp()); + if (!element_op || + element_op.indexAttr().getInt() != element_and_idx.index() + 1 || + element_op.getOperand() != tuple_predecessor) + return failure(); + } + + rewriter.replaceOp(op, tuple_predecessor); + return success(); + } +}; + +} // namespace + +void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // AllToAllOp //===----------------------------------------------------------------------===// @@ -707,10 +744,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) { auto dimSize = operandType.getDimSize(i); auto resultDimSize = resultType.getDimSize(dimIndex); - if (dimSize != 1 && dimSize != resultDimSize) { + // Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we + // add a manual check for this. + if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) { return op.emitOpError( - llvm::formatv("size of operand dimension {0} ({1}) is not equal to " - "1 or size of result dimension {2} ({3})", + llvm::formatv("size of operand dimension {0} ({1}) is not compatible " + "with size of result dimension {2} ({3})", i, dimSize, dimIndex, resultDimSize)); } } @@ -744,7 +783,9 @@ class DynamicBroadcastInDimOpNotActuallyDynamic void DynamicBroadcastInDimOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + results.insert( + context); } //===----------------------------------------------------------------------===// @@ -1465,7 +1506,7 @@ static LogicalResult Verify(PadOp op) { static LogicalResult Verify(ReshapeOp op) { // If the operand type is dynamically shaped there is nothing to verify. 
- auto operand_ty = op.operand().getType().cast(); + auto operand_ty = op.operand().getType().dyn_cast(); if (!operand_ty || !operand_ty.hasStaticShape()) return success(); // If the operand type is statically shaped (not required) the number of @@ -2119,7 +2160,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, } #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" //===----------------------------------------------------------------------===// // mhlo Dialect Interfaces @@ -2147,10 +2188,10 @@ struct HLOInlinerInterface : public DialectInlinerInterface { //===----------------------------------------------------------------------===// MhloDialect::MhloDialect(MLIRContext* context) - : Dialect(getDialectNamespace(), context) { + : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" >(); addInterfaces(); addTypes(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td new file mode 100644 index 00000000000..b8b6cb80fba --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Canonicalization patterns for the MHLO dialect. + +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" + +def EqualBinaryOperands : Constraint>; + +// Canonicalization patterns. + +def DynamicBroadcastToOwnShape_1 : Pat< + (HLO_DynamicBroadcastInDimOp:$op $arg0, + (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr), + (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; +def DynamicBroadcastToOwnShape_2 : Pat< + (HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr), + (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; + diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc index eaa3414b36a..e93a6cfce3d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" namespace mlir { -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cpp.inc" } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc new file mode 100644 index 00000000000..9fffeae1cc5 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/init.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/register.h" + +// Static initialization for *HLO dialects registration. + +void mlir::mhlo::registerAllDialects() { + static bool init_once = []() { + registerDialect(); + registerDialect(); + registerDialect(); + return true; + }(); + (void)init_once; + + // Dependent dialects +} diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index bd0dc224ccc..f61a66397e7 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines the operations used in the LMHLO dialect. -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include #include @@ -28,31 +28,31 @@ limitations under the License. 
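A hedged usage sketch for the registerAllDialects() entry point added in init.cc above; the main() scaffolding and the assumption that register.h declares the function are illustrative, not part of this change:

#include "mlir-hlo/Dialect/mhlo/IR/register.h"  // assumed to declare mlir::mhlo::registerAllDialects()
#include "mlir/IR/MLIRContext.h"

int main() {
  // With the 2020-era global dialect registry, registration must happen
  // before any *HLO IR is parsed or created.
  mlir::mhlo::registerAllDialects();
  mlir::MLIRContext context;
  // ... build or parse mhlo/chlo/lmhlo modules using `context` ...
  return 0;
}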
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/OpImplementation.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" namespace mlir { -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops_structs.cc.inc" namespace lmhlo { LmhloDialect::LmhloDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { + : Dialect(getDialectNamespace(), context, TypeID::get()) { addOperations< #define GET_OP_LIST -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" >(); } @@ -127,7 +127,7 @@ static LogicalResult Verify(ReshapeMemRefCastOp op) { } #define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" // TODO(cheshire): Support folding, reuse code from hlo_ops.cc. diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td new file mode 100644 index 00000000000..eb92d9e0e46 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the canonicalize pattern definition file. 
+ +include "mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" + +def UnaryToBinaryEinsumEq : NativeCodeCall< + "$_builder.getStringAttr(\",\" + $0.getValue().str())">; + +// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first +// operand. +def UnaryEinsumToEinsum : Pat< + (HLO_UnaryEinsumOp $operand, $equation), + (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), + $operand, (UnaryToBinaryEinsumEq $equation))>; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt new file mode 100644 index 00000000000..bb9f98d32d3 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -0,0 +1,155 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +include_directories(BEFORE + ${CMAKE_CURRENT_BINARY_DIR} + ${CMAKE_CURRENT_SOURCE_DIR}) + +set(LLVM_TARGET_DEFINITIONS lower_complex_patterns.td) +mlir_tablegen(generated_lower_complex.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloLowerComplexIncGen) + +set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td) +mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters) +add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen) + + +add_mlir_library(ChloPasses + chlo_legalize_to_hlo.cc + chlo_legalize_to_hlo_pass.cc + + DEPENDS + MLIRhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + ChloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(MhloPasses + legalize_gather_to_torch_index_select.cc + legalize_tanh_to_approximation.cc + lower_complex.cc + lower_complex_patterns.td + lower_general_dot.cc + materialize_broadcasts.cc + materialize_broadcasts_pass.cc + mhlo_fusion.cc + optimize_mhlo.cc + optimize_mhlo_pass.cc + sink_constants_to_control_flow.cc + test_infer_shaped_type_pass.cc + transform_unranked_hlo.cc + unfuse_batch_norm.cc + unfuse_batch_norm_pass.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRMhloLowerComplexIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRMhloUtils + MLIRPass + MLIRTransformUtils +) + +add_mlir_library(MhloToLhloConversion + hlo_legalize_to_lhlo.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MhloDialect + LmhloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(MhloToStandard + legalize_control_flow.cc + legalize_to_standard.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + MLIRMhloLegalizeToStandardIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass +) + +add_mlir_library(MhloLhloToLinalg + legalize_to_linalg.cc + + DEPENDS + MLIRhlo_opsIncGen + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MhloDialect + MLIRIR + MLIRPass +) + +add_mlir_library(LmhloPasses + lhlo_copy_removal.cc + lhlo_fuse_linalg.cc + lhlo_legalize_to_affine.cc + lhlo_legalize_to_gpu.cc + 
lhlo_legalize_to_llvm.cc + lhlo_legalize_to_llvm_pass.cc + lhlo_legalize_to_parallel_loops.cc + + DEPENDS + MLIRlhlo_opsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + LmhloDialect + MLIRIR + MLIRPass +) + +add_library(AllMhloPasses INTERFACE) +target_link_libraries(AllMhloPasses INTERFACE + ChloPasses + MhloPasses + MhloToLhloConversion + MhloToStandard + MhloLhloToLinalg + LmhloPasses +) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index 06e95e04c76..c2db4880632 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -13,20 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/utils/broadcast_utils.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace chlo { - namespace { // Converts binary ops that statically are determined to not broadcast directly @@ -74,10 +76,6 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { // - Legal combinations of degenerate (1-dim) implicit broadcasting. // The restriction on broadcast_dims derives from the definition of the // `shape.broadcast` op, which only supports prefix-padding. -// -// It may be possible to expand this pattern to operate on unranked tensors in -// the future by emitting more code to dynamically differentiate based on rank. -// Whether that is of any practical benefit remains to be seen. template struct ConvertRankedDynamicBroadcastBinaryOp : public OpRewritePattern { @@ -126,8 +124,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); Value result_extents = - hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, - rewriter); + hlo::ComputeBinaryElementwiseBroadcastingResultExtents( + loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true); // Note that we unconditionally emit DynamicBroadcastInDim ops and let // downstream canonicalizations fold them away if possible. 
This is @@ -160,6 +158,273 @@ struct ConvertRankedDynamicBroadcastBinaryOp } }; +// Converts a broadcasting binary operation with a scalar operand and an +// unranked operand to a ranked broadcasting operation by dynamically reshaping +// the unranked operand to a 1D tensor. This will always be safe because +// broadcasting from a scalar to another shape always works. +template +struct ConvertUnrankedScalarDynamicBroadcastBinaryOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value lhs = op.lhs(); + Value rhs = op.rhs(); + + auto lhs_ranked_type = lhs.getType().dyn_cast(); + auto lhs_unranked_type = lhs.getType().dyn_cast(); + + auto rhs_ranked_type = rhs.getType().dyn_cast(); + auto rhs_unranked_type = rhs.getType().dyn_cast(); + + bool lhs_is_scalar = lhs_ranked_type && + lhs_ranked_type.getShape().empty() && + rhs_unranked_type; + bool rhs_is_scalar = rhs_ranked_type && + rhs_ranked_type.getShape().empty() && + lhs_unranked_type; + + // Only support the case where exactly one operand is scalar and the other + // is unranked. Other patterns in this file will create more efficient + // lowerings for cases where both ranks are known or will handle the more + // generic case of both inputs being unranked. + if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure(); + + auto result_type = op.getResult().getType().template dyn_cast(); + + // Reshape the non-scalar value into a dynamically sized, rank-1 tensor + Value shape = + rewriter.create(loc, lhs_is_scalar ? rhs : lhs); + Value num_elements = rewriter.create(loc, shape); + Value size_tensor = + rewriter.create(loc, num_elements); + Value reshaped = rewriter.create( + loc, RankedTensorType::get({-1}, result_type.getElementType()), + lhs_is_scalar ? rhs : lhs, size_tensor); + + // Create a new ranked Chlo op that will be further lowered by other + // patterns into Mhlo. + SmallVector operands{lhs_is_scalar ? lhs : reshaped, + rhs_is_scalar ? rhs : reshaped}; + Value computed = rewriter.create( + loc, SmallVector{reshaped.getType()}, operands, op.getAttrs()); + + // Reshape the result back into an unranked tensor. + rewriter.replaceOpWithNewOp(op, result_type, + computed, shape); + + return success(); + } +}; + +// Handles lowering of the following pattern to patterns that will be further +// matched by other patterns until they result in LHLO: +// %result = "chlo.op"(%lhs, %rhs) : (<*xTy>, <*xTy>) -> <*xTy> +// +// The sequence of specializations this handles is: +// - Either operand being scalar +// - Operands having equal shapes +// - The resulting value being any of ranks [2,6] +template +struct ConvertUnrankedDynamicBroadcastBinaryOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value lhs = op.lhs(); + Value rhs = op.rhs(); + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + auto result_type = op.getResult().getType().template dyn_cast(); + + // Only support unranked operands. If either operand is ranked, another + // pattern will handle the lowering. 
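As a reading aid for the scalar/unranked pattern above: the unranked side is flattened to a dynamically sized rank-1 tensor before the ranked broadcast is applied, then reshaped back afterwards. A condensed sketch of the flattening step, with the op class names (shape::ShapeOfOp, shape::NumElementsOp, TensorFromElementsOp, mhlo::DynamicReshapeOp) filled in as assumptions inferred from the surrounding code:

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"

// Flattens an unranked tensor value to a rank-1 tensor with a dynamic size,
// mirroring the reshape step of the scalar/unranked pattern above.
static mlir::Value FlattenToRank1(mlir::PatternRewriter &rewriter,
                                  mlir::Location loc, mlir::Value unranked,
                                  mlir::Type element_type) {
  mlir::Value shape = rewriter.create<mlir::shape::ShapeOfOp>(loc, unranked);
  mlir::Value num_elements =
      rewriter.create<mlir::shape::NumElementsOp>(loc, shape);
  mlir::Value size_tensor =
      rewriter.create<mlir::TensorFromElementsOp>(loc, num_elements);
  return rewriter.create<mlir::mhlo::DynamicReshapeOp>(
      loc, mlir::RankedTensorType::get({-1}, element_type), unranked,
      size_tensor);
}

The check that follows restricts this second, fully general pattern to the case where both operands are unranked.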
+ if (!lhs_type || !rhs_type) return failure();
+
+ // If lhs is scalar
+ auto if_op = rewriter.create(
+ loc, result_type, IsScalarTensor(rewriter, op, lhs), true);
+ OpBuilder if_lhs_scalar_builder = if_op.getThenBodyBuilder();
+ Value reshaped_lhs = if_lhs_scalar_builder.create(
+ loc, RankedTensorType::get({}, lhs_type.getElementType()), lhs);
+ Value if_lhs_scalar_result = if_lhs_scalar_builder.create(
+ loc, ArrayRef{result_type}, ArrayRef{reshaped_lhs, rhs},
+ op.getAttrs());
+ if_lhs_scalar_builder.create(loc, if_lhs_scalar_result);
+
+ // If lhs is NOT scalar
+ //
+ // See if rhs is scalar
+ OpBuilder else_lhs_scalar_builder = if_op.getElseBodyBuilder();
+ auto if_rhs_scalar_op = else_lhs_scalar_builder.create(
+ loc, result_type, IsScalarTensor(else_lhs_scalar_builder, op, rhs),
+ true);
+ else_lhs_scalar_builder.create(loc,
+ if_rhs_scalar_op.getResult(0));
+ OpBuilder if_rhs_scalar_builder = if_rhs_scalar_op.getThenBodyBuilder();
+ Value reshaped_rhs = if_rhs_scalar_builder.create(
+ loc, RankedTensorType::get({}, lhs_type.getElementType()), rhs);
+ Value if_rhs_scalar_result = if_rhs_scalar_builder.create(
+ loc, ArrayRef{result_type}, ArrayRef{lhs, reshaped_rhs},
+ op.getAttrs());
+ if_rhs_scalar_builder.create(loc, if_rhs_scalar_result);
+
+ // If NEITHER shape is scalar
+ //
+ // See if shapes are equal.
+ OpBuilder else_no_scalars_builder = if_rhs_scalar_op.getElseBodyBuilder();
+ Value shape_of_lhs =
+ else_no_scalars_builder.create(loc, lhs);
+ Value shape_of_rhs =
+ else_no_scalars_builder.create(loc, rhs);
+ Value equal_shapes = else_no_scalars_builder.create(
+ loc, shape_of_lhs, shape_of_rhs);
+
+ auto if_eq_shapes_op = else_no_scalars_builder.create(
+ loc, result_type, equal_shapes, true);
+ else_no_scalars_builder.create(loc,
+ if_eq_shapes_op.getResult(0));
+
+ OpBuilder if_eq_shapes_builder = if_eq_shapes_op.getThenBodyBuilder();
+ Value non_broadcast_op =
+ Adaptor::CreateOp(op, result_type, lhs, rhs, if_eq_shapes_builder);
+ if_eq_shapes_builder.create(loc, non_broadcast_op);
+
+ // If shapes are not scalar, nor equal
+ //
+ // See if values are of a rank that we support.
+ OpBuilder if_neq_shapes_builder = if_eq_shapes_op.getElseBodyBuilder();
+ if_neq_shapes_builder.create(
+ loc, HandleBroadcastAndOp(if_neq_shapes_builder, op, lhs, rhs));
+
+ rewriter.replaceOp(op, {if_op.getResult(0)});
+ return success();
+ }
+
+ private:
+ // Returns the dynamic result of checking whether the given value is a scalar
+ // tensor.
+ Value IsScalarTensor(OpBuilder &rewriter, ChloOpTy op, Value tensor) const {
+ auto loc = op.getLoc();
+
+ Value shape_of_tensor = rewriter.create(loc, tensor);
+ Value rank_tensor = rewriter.create(
+ loc, rewriter.getIndexType(), shape_of_tensor);
+ return rewriter.create(loc, rewriter.getI1Type(), CmpIPredicate::eq,
+ rank_tensor,
+ rewriter.create(loc, 0));
+ }
+
+ // Create the if statement and code for a broadcasting op with a result of a
+ // given rank.
+ scf::IfOp createRankSpecializedBroadcastAndOp(OpBuilder &builder, ChloOpTy op,
+ Value lhs, Value rhs,
+ Value actual_rank,
+ int targeted_rank) const {
+ auto loc = op.getLoc();
+
+ // Create the if block to place the current specialized logic in.
+ Value greater_rank_is_n = builder.create(
+ loc, CmpIPredicate::eq, actual_rank,
+ builder.create(loc, targeted_rank));
+ auto if_op =
+ builder.create(loc, lhs.getType(), greater_rank_is_n, true);
+ OpBuilder if_builder = if_op.getThenBodyBuilder();
+
+ // Handle shape broadcasting and inference.
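The IsScalarTensor helper above asks, at runtime, whether a tensor has rank zero. A standalone sketch with the elided op classes written out as assumptions (shape::ShapeOfOp, shape::RankOp, ConstantIndexOp, CmpIOp):

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"

// Emits IR that answers "is this tensor rank 0?" at runtime. The op class
// names are assumptions inferred from the variable names used above.
static mlir::Value EmitIsScalarCheck(mlir::OpBuilder &b, mlir::Location loc,
                                     mlir::Value tensor) {
  mlir::Value shape = b.create<mlir::shape::ShapeOfOp>(loc, tensor);
  mlir::Value rank =
      b.create<mlir::shape::RankOp>(loc, b.getIndexType(), shape);
  mlir::Value zero = b.create<mlir::ConstantIndexOp>(loc, 0);
  return b.create<mlir::CmpIOp>(loc, mlir::CmpIPredicate::eq, rank, zero);
}

The builder code that continues below then pads both operand shapes to the targeted rank and applies the ranked broadcast.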
+ Value lhs_shape = if_builder.create(loc, lhs); + Value rhs_shape = if_builder.create(loc, rhs); + SmallVector ranked_shape(targeted_rank, 1); + auto extent_tensor_type = + RankedTensorType::get({targeted_rank}, builder.getIndexType()); + auto reshaped_type = RankedTensorType::get( + llvm::SmallVector(targeted_rank, + RankedTensorType::kDynamicSize), + lhs.getType().template dyn_cast().getElementType()); + Value ranked_shape_val = if_builder.create( + loc, extent_tensor_type, + mlir::DenseIntElementsAttr::get(extent_tensor_type, ranked_shape)); + // TODO(tpopp): Return extent tensors when possible to signal that this is a + // guaranteed safe broadcast by construction. + Value extended_lhs = if_builder.create( + loc, extent_tensor_type, lhs_shape, ranked_shape_val, nullptr); + Value extended_rhs = if_builder.create( + loc, extent_tensor_type, rhs_shape, ranked_shape_val, nullptr); + + // 1. Reshape operands to the given rank (with the same number of elements) + // 2. Compute the ranked-broadcasted ChloOp (which will assert that the ops + // can be broadcasted and do the actual broadcasting) + // 3. Type erase the output back to unranked + Value reshaped_lhs = if_builder.create( + loc, reshaped_type, lhs, extended_lhs); + Value reshaped_rhs = if_builder.create( + loc, reshaped_type, rhs, extended_rhs); + Value result = if_builder.create( + loc, ArrayRef{reshaped_type}, + ArrayRef{reshaped_lhs, reshaped_rhs}, op.getAttrs()); + Value reshaped_result = if_builder.create( + loc, UnrankedTensorType::get(reshaped_type.getElementType()), result); + if_builder.create(loc, reshaped_result); + + // Return the if_op, so the result can be used and the else block can be + // used for the next rank specialized step. + return if_op; + } + + // Iterates over the desired ranks to be specialized and generates the code + // snippet for each case. + Value HandleBroadcastAndOp(OpBuilder &rewriter, ChloOpTy op, Value lhs, + Value rhs) const { + constexpr int max_rank_specialization = 7; + auto loc = op.getLoc(); + + // Find the larger rank of the 2 operands. + auto extent_tensor_type = RankedTensorType::get({ShapedType::kDynamicSize}, + rewriter.getIndexType()); + Value lhs_shape = + rewriter.create(loc, extent_tensor_type, lhs); + Value rhs_shape = + rewriter.create(loc, extent_tensor_type, rhs); + Value lhs_rank = + rewriter.create(loc, rewriter.getIndexType(), lhs_shape); + Value rhs_rank = + rewriter.create(loc, rewriter.getIndexType(), rhs_shape); + Value greater_rank_lhs = + rewriter.create(loc, CmpIPredicate::sgt, lhs_rank, rhs_rank); + Value greater_rank = + rewriter.create(loc, greater_rank_lhs, lhs_rank, rhs_rank); + + // Generate a list of nested if/else statements to handle rank + // specializations from 2-6. + scf::IfOp if_op = createRankSpecializedBroadcastAndOp(rewriter, op, lhs, + rhs, greater_rank, 2); + + // Put each subsequent rank specialization inside the else statement of the + // previous one. + OpBuilder else_builder = if_op.getElseBodyBuilder(); + for (int i = 3; i < max_rank_specialization; i++) { + auto inner_if = createRankSpecializedBroadcastAndOp(else_builder, op, lhs, + rhs, greater_rank, i); + + else_builder.create(loc, inner_if.getResult(0)); + else_builder = inner_if.getElseBodyBuilder(); + } + + // Fire an assertion if none of the rank specializations applied (one of the + // ranks was greater than 6). 
+ else_builder.create( + loc, else_builder.create(loc, 0, 1), + "Input for dynamic binary op lowering was of a rank greater than 6"); + else_builder.create(loc, lhs); + + // Return the result of the outermost if statement. + return if_op.getResult(0); + } +}; + template void PopulateForBinaryOp(MLIRContext *context, OwningRewritePatternList *patterns) { @@ -169,6 +434,10 @@ void PopulateForBinaryOp(MLIRContext *context, patterns->insert< ConvertRankedDynamicBroadcastBinaryOp>( context, 5); + patterns->insert< + ConvertUnrankedScalarDynamicBroadcastBinaryOp, + ConvertUnrankedDynamicBroadcastBinaryOp>( + context); } template diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index 48749c7d43d..50cd6df5c99 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace chlo { +namespace mhlo { namespace { @@ -31,14 +33,15 @@ struct TestChloLegalizeToHloPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); + conversionTarget.addIllegalDialect(); // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. 
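A hedged sketch of the dialect-conversion driver that the surrounding test pass sets up: mark chlo illegal, mark the target dialects legal, populate the patterns, and apply a partial conversion. The dialect class names (chlo::HloClientDialect and friends) are assumptions; the structure mirrors TestChloLegalizeToHloPass:

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"

// Runs the chlo->mhlo patterns over a single function.
static mlir::LogicalResult LowerChloToHlo(mlir::FuncOp func) {
  mlir::MLIRContext *context = func.getContext();
  mlir::ConversionTarget target(*context);
  target.addIllegalDialect<mlir::chlo::HloClientDialect>();
  target.addLegalDialect<mlir::mhlo::MhloDialect, mlir::StandardOpsDialect,
                         mlir::shape::ShapeDialect, mlir::scf::SCFDialect>();

  mlir::OwningRewritePatternList patterns;
  mlir::chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
  return mlir::applyPartialConversion(func, target, patterns);
}

The remaining addLegalDialect calls below additionally admit the SCF dialect, which the new unranked lowerings emit.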
conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); - PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); + chlo::PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); if (failed(applyPartialConversion(getFunction(), conversionTarget, conversionPatterns))) { @@ -49,9 +52,10 @@ struct TestChloLegalizeToHloPass } // namespace -} // namespace chlo +std::unique_ptr createTestChloLegalizeToHloPass() { + return std::make_unique(); +} + +} // namespace mhlo } // namespace mlir -static mlir::PassRegistration pass( - "mhlo-test-chlo-legalize-to-hlo", - "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 4ee45d56a8e..a8c3ad17ebb 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -15,26 +15,25 @@ limitations under the License. // This file implements logic for lowering HLO dialect to LHLO dialect. -#include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/AffineMap.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/BufferPlacement.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -42,9 +41,6 @@ namespace { template using BaseOpConversion = BufferAssignmentOpConversionPattern; -using StdReturnOpConverter = - detail::BufferAssignmentReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location 
loc, Value result, Value shape_operand, @@ -272,27 +268,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); - // Create new block arguments with correct type. + // Convert the region signature to memref and add extra result. auto& entry_block = new_op.body().front(); - int original_arg_count = entry_block.getNumArguments(); - for (int i = 0; i < original_arg_count; ++i) { - auto old_arg = entry_block.getArgument(i); - auto old_type = old_arg.getType().cast(); + TypeConverter::SignatureConversion sig_conversion( + entry_block.getNumArguments() + 1); + for (auto arg : entry_block.getArguments()) { + auto old_type = arg.getType().cast(); auto new_type = MemRefType::get(old_type.getShape(), old_type.getElementType()); - auto new_arg = entry_block.addArgument(new_type); - rewriter.replaceUsesOfBlockArgument(old_arg, new_arg); + sig_conversion.addInputs(arg.getArgNumber(), new_type); } - // Add an argument for the result. - entry_block.addArgument( - entry_block.getArgument(original_arg_count).getType()); - // Remove the old arguments. - for (int i = original_arg_count - 1; i >= 0; --i) { - entry_block.eraseArgument(i); - } - // Insert terminator at the end. - rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + auto return_op = cast(entry_block.getTerminator()); + auto result_type = return_op.results().front().getType().cast(); + sig_conversion.addInputs({MemRefType::get(result_type.getShape(), + result_type.getElementType())}); + rewriter.applySignatureConversion(&new_op.body(), sig_conversion); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -300,6 +290,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality +// is provided by mlir buffer assignment, so use the pattern from there. +// TODO(DFKI): Move this out of detail. +using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter< + mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>; + class HloToLhloTensorLoadOpConverter : public BaseOpConversion { public: @@ -312,7 +308,6 @@ class HloToLhloTensorLoadOpConverter } }; -// TODO(b/137624192): Rewrite into a copy and elide copy if possible. 
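The reduce-region change above replaces manual block-argument surgery with a TypeConverter::SignatureConversion. A minimal sketch of that idiom, converting every tensor block argument of a region to a memref (the extra result buffer the real converter appends is omitted for brevity):

#include "mlir/IR/Block.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/DialectConversion.h"

// Builds a SignatureConversion mapping each (ranked) tensor argument of the
// region's entry block to a memref of the same shape, then lets the rewriter
// rewrite the region signature in one shot.
static void ConvertRegionArgsToMemrefs(
    mlir::Region &region, mlir::ConversionPatternRewriter &rewriter) {
  mlir::Block &entry_block = region.front();
  mlir::TypeConverter::SignatureConversion sig_conversion(
      entry_block.getNumArguments());
  for (mlir::BlockArgument arg : entry_block.getArguments()) {
    auto tensor_type = arg.getType().cast<mlir::TensorType>();
    sig_conversion.addInputs(
        arg.getArgNumber(),
        mlir::MemRefType::get(tensor_type.getShape(),
                              tensor_type.getElementType()));
  }
  rewriter.applySignatureConversion(&region, sig_conversion);
}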
class HloToLhloTensorStoreOpConverter : public BaseOpConversion { public: @@ -506,6 +501,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloReduceOpConverter, + HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter >(context, bufferAssignment, converter); @@ -514,11 +510,8 @@ void populateHLOToLHLOConversionPattern( std::unique_ptr> createLegalizeToLhloPass( bool results_escape_function) { - return absl::make_unique(results_escape_function); + return std::make_unique(results_escape_function); } -static PassRegistration legalize_pass( - "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); - } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc index 440df7ec23f..b6e23a6b131 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -18,27 +18,27 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Block.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LogicalResult.h" using mlir::PassRegistration; namespace mlir { namespace mhlo { namespace { -struct LegalizeControlFlow - : public mlir::PassWrapper { +struct LegalizeControlFlowPass + : public mlir::PassWrapper { // Perform the lowering to MLIR control flow. 
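With the static PassRegistration objects removed in this change, these passes are meant to be instantiated through their create*Pass() factories. A sketch of wiring two of them into a PassManager; the particular pipeline composition is an illustrative assumption, not one defined by this change:

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"

// Adds two of the function-level lowerings via their factory functions.
static void BuildExamplePipeline(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeControlFlowPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeToStdPass());
}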
void runOnFunction() override; }; @@ -206,7 +206,7 @@ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { return success(); } -void LegalizeControlFlow::runOnFunction() { +void LegalizeControlFlowPass::runOnFunction() { auto func = getFunction(); llvm::SmallVector if_ops; func.walk([&](IfOp op) { if_ops.push_back(op); }); @@ -228,9 +228,5 @@ void LegalizeControlFlow::runOnFunction() { std::unique_ptr> mlir::mhlo::createLegalizeControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_cf_pass( - "mhlo-legalize-control-flow", - "Legalize from MHLO control flow to CFG control flow"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc new file mode 100644 index 00000000000..59cd3381133 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc @@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace mhlo { +namespace { + +struct GatherIsTorchIndexSelect : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter &rewriter) const override { + auto start_indices = gather.start_indices(); + auto start_indices_ty = start_indices.getType().cast(); + if (!start_indices_ty.hasRank()) { + return failure(); + } + + auto operand = gather.operand(); + auto operand_ty = operand.getType().cast(); + if (!operand_ty.hasRank()) { + return failure(); + } + + int64_t index_vector_dim = + std::max(0, start_indices_ty.getRank() - 1); + + // We can use torch_index_select if the last dimension represents the + // gather indices. + auto dimension_numbers = gather.dimension_numbers(); + if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != + index_vector_dim) { + return failure(); + } + + // Index select only works across a single dimension. + if (!start_indices_ty.getShape().empty() && + start_indices_ty.getShape().back() != 1) { + return failure(); + } + + // Only support the default case for start_index_map. + if (dimension_numbers.start_index_map().getType().getRank() != 1 || + dimension_numbers.start_index_map() + .getValue(0) + .cast() + .getValue() != 0) { + return failure(); + } + + auto result_ty = gather.getResult().getType().dyn_cast(); + if (!result_ty) { + return failure(); + } + + // Offset dimensions should be the defaults. 
+ if (dimension_numbers.offset_dims().getType().getNumElements() != + result_ty.getRank() - index_vector_dim) { + return failure(); + } + + for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) { + if ((it.index() + index_vector_dim) != it.value()) { + return failure(); + } + } + + for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) { + // First shape value must be 1. + if (it.index() == 0) { + if (it.value().getSExtValue() != 1) { + return failure(); + } + continue; + } + + // The gather needs to index the entire slice for each other dimension. + if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) { + return failure(); + } + } + + llvm::SmallVector index_select_shape = + llvm::to_vector<4>(start_indices_ty.getShape()); + + for (auto dim : operand_ty.getShape().drop_front()) { + index_select_shape.push_back(dim); + } + + if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() || + dimension_numbers.collapsed_slice_dims().getType().getNumElements() != + 1 || + dimension_numbers.collapsed_slice_dims().getValue({0}) != 0) { + return failure(); + } + + auto torch_index_select = rewriter.create( + gather.getLoc(), + RankedTensorType::get(index_select_shape, operand_ty.getElementType()), + operand, gather.start_indices(), rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(0)); + + rewriter.replaceOpWithNewOp(gather, gather.getType(), + torch_index_select); + + return success(); + } +}; + +struct LegalizeGatherToTorchIndexSelectPass + : public PassWrapper { + /// Perform the lowering of standard dialect operations to approximations. + void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; +} // namespace + +void PopulateGatherToTorchIndexSelectPatterns( + mlir::MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert(context); +} + +std::unique_ptr createLegalizeGatherToTorchIndexSelectPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc index 1890646160e..57c494f536b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc @@ -16,15 +16,15 @@ limitations under the License. // This file implements logic for lowering the tanh standard ops to an // approximation. -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace hlo { +namespace mhlo { namespace { /// Emits the fast tanh approximation that is also used by XLA. 
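The new gather pass above follows the usual populate-then-apply shape. A compact sketch of that driver, reusing the PopulateGatherToTorchIndexSelectPatterns hook it defines (assumed to be declared in rewriters.h):

#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"

// Mirrors LegalizeGatherToTorchIndexSelectPass::runOnFunction(): collect the
// rewrite patterns and drive them greedily over one function.
static void RewriteGathersIn(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  mlir::mhlo::PopulateGatherToTorchIndexSelectPatterns(func.getContext(),
                                                       &patterns);
  mlir::applyPatternsAndFoldGreedily(func, patterns);
}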
@@ -126,8 +126,8 @@ class ApproximateTanhLowering : public OpRewritePattern { } }; -struct LegalizeTanhToApproximation - : public PassWrapper { +struct LegalizeTanhToApproximationPass + : public PassWrapper { /// Perform the lowering of standard dialect operations to approximations. void runOnFunction() override { OwningRewritePatternList patterns; @@ -140,7 +140,7 @@ struct LegalizeTanhToApproximation std::unique_ptr> createLegalizeTanhToApproximationPass() { - return std::make_unique(); + return std::make_unique(); } void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, @@ -148,9 +148,5 @@ void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, patterns->insert(context); } -static PassRegistration legalize_pass( - "mhlo-legalize-tanh-to-approximation", - "Legalize tanh from standard dialect to an approximation"); - -} // namespace hlo +} // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 717e9682436..f47f2c2fbdc 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -15,26 +15,25 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. -#include "absl/memory/memory.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace { @@ -298,8 +297,8 
@@ class DataMovementOpConverter : public OpConversionPattern { auto nloops = resultType.getRank(); auto loc = op.getLoc(); auto linalgOp = rewriter.create( - loc, isLHLO ? ArrayRef{} : resultType, args, /*inputCount=*/1, - /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), + loc, isLHLO ? ArrayRef{} : resultType, args, /*argsIn=*/1, + /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); }); @@ -420,7 +419,7 @@ class LhloBroadcastInDimConverter rewriter.create(loc, operand, llvm::makeArrayRef({zero})); rewriter.create( loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), - /*inputCount=*/0, /*outputCount=*/1, + /*argsIn=*/0, /*argsOut=*/1, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { @@ -433,7 +432,7 @@ class LhloBroadcastInDimConverter rewriter.create( loc, llvm::None, llvm::makeArrayRef({operand, operand_adaptor.output()}), - /*inputCount=*/1, /*outputCount=*/1, indexing_maps, + /*argsIn=*/1, /*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { nestedBuilder.create(loc, *args.begin()); @@ -640,25 +639,25 @@ class ReshapeOpConverter : public OpConversionPattern { } }; -class IotaConverter : public OpConversionPattern { +template +class IotaConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - lmhlo::IotaOp iotaOp, ArrayRef args, + OpTy iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - auto resultMemrefType = - iotaOp.getOperand().getType().dyn_cast(); - if (!resultMemrefType) return failure(); + ShapedType resultShapedType = getHloOpResultType(iotaOp); + if (!resultShapedType) return failure(); - auto resultElementType = resultMemrefType.getElementType(); + auto resultElementType = resultShapedType.getElementType(); if (!resultElementType.isSignlessIntOrFloat()) return failure(); // Construct the indexing maps needed for linalg.generic ops. - unsigned nloops = resultMemrefType.getRank(); + unsigned nloops = resultShapedType.getRank(); - rewriter.create( - iotaOp.getLoc(), ArrayRef{}, args, + auto linalgOp = rewriter.create( + iotaOp.getLoc(), isLHLO ? 
ArrayRef{} : resultShapedType, args, 0, // args_in 1, // args_out llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), @@ -669,14 +668,16 @@ class IotaConverter : public OpConversionPattern { nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], nestedBuilder.getIntegerType( resultElementType.getIntOrFloatBitWidth())); - if (resultElementType.isa()) { + if (resultElementType.template isa()) { castOp = nestedBuilder.create(nestedLoc, castOp, resultElementType); } nestedBuilder.create(nestedLoc, castOp); }); - - rewriter.replaceOp(iotaOp, llvm::None); + if (isLHLO) + rewriter.replaceOp(iotaOp, llvm::None); + else + rewriter.replaceOp(iotaOp, linalgOp.output_tensors()); return success(); } }; @@ -768,7 +769,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, patterns->insert, ConstConverter, ConvToLinalgConverter, - IotaConverter, + IotaConverter, LhloBroadcastInDimConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -824,8 +825,8 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, // indexing_maps = [#map0, #map0, #map0], // iterator_types = ["parallel", "parallel"], // } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -struct LhloLegalizeToLinalg - : public PassWrapper { +struct LhloLegalizeToLinalgPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -840,8 +841,8 @@ struct LhloLegalizeToLinalg } }; -struct HloLegalizeToLinalg - : public PassWrapper { +struct HloLegalizeToLinalgPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -859,54 +860,49 @@ struct HloLegalizeToLinalg namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { - return absl::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_lhlo_pass( - "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); } // namespace lmhlo namespace mhlo { void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { - patterns->insert, - HloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - TransposeConverter>(context); + patterns + ->insert, + HloBroadcastInDimConverter, IotaConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + 
PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter>(context); } std::unique_ptr> createLegalizeHloToLinalgPass() { - return absl::make_unique(); + return std::make_unique(); } - -static PassRegistration legalize_hlo_pass( - "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc index c71aa1d0460..cc574e008d5 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -16,17 +16,17 @@ limitations under the License. // This file implements logic for lowering MHLO dialect to Standard dialect. #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace { -#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" +#include "generated_legalize_to_standard.inc" } // end anonymous namespace namespace mhlo { namespace { @@ -176,15 +176,15 @@ class ConvertIotaOp : public OpRewritePattern { } // end anonymous namespace namespace { -struct LegalizeToStandard - : public PassWrapper { +struct LegalizeToStandardPass + : public PassWrapper { /// Perform the lowering to Standard dialect. void runOnFunction() override; }; } // end anonymous namespace std::unique_ptr> createLegalizeToStdPass() { - return std::make_unique(); + return std::make_unique(); } void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, @@ -194,14 +194,11 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns, } /// Perform the lowering to standard dialect. 
-void LegalizeToStandard::runOnFunction() { +void LegalizeToStandardPass::runOnFunction() { OwningRewritePatternList patterns; mlir::mhlo::PopulateMhloToStdPatterns(&patterns, &getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } -static PassRegistration legalize_pass( - "mhlo-legalize-to-std", "Legalize from MHLO dialect to standard dialect"); - } // end namespace mhlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td index 0e6fdf06701..ea67c052c5c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -17,7 +17,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" -include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" //===----------------------------------------------------------------------===// // Nullary op patterns. diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc index d2607887482..7a4418466b5 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -15,12 +15,11 @@ limitations under the License. // This file implements a pass to remove redundant LHLO copy operations. -#include "absl/memory/memory.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -30,7 +29,8 @@ namespace { // arguments. All uses of each buffer are replaced with the corresponding block // argument and the buffer is freed. Note that this pass only works in regions // with a single block. -struct LhloCopyRemoval : mlir::PassWrapper> { +struct LhloCopyRemovalPass + : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); @@ -95,11 +95,8 @@ struct LhloCopyRemoval : mlir::PassWrapper> { } // namespace std::unique_ptr createLhloCopyRemovalPass() { - return absl::make_unique(); + return std::make_unique(); } -static PassRegistration copy_removal_pass( - "lhlo-copy-removal", "Removes redundant LHLO copy operations"); - } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index d832b96bf7b..1467f015dc9 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -16,15 +16,14 @@ limitations under the License. 
// This file implements logic for fusing linalg ops obtained after LHLO // lowering. -#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/FoldUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" namespace mlir { namespace lmhlo { @@ -32,11 +31,13 @@ namespace { using linalg::LinalgOp; -class LhloFuseLinalg : public PassWrapper { +class LhloFuseLinalgPass + : public PassWrapper { public: - LhloFuseLinalg() = default; - LhloFuseLinalg(const LhloFuseLinalg&) {} - LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { + LhloFuseLinalgPass() = default; + LhloFuseLinalgPass(const LhloFuseLinalgPass&) {} + LhloFuseLinalgPass(bool use_parallel_loops, + llvm::ArrayRef tile_sizes) { tile_sizes_ = tile_sizes; use_parallel_loops_.setValue(use_parallel_loops); } @@ -138,14 +139,10 @@ class LhloFuseLinalg : public PassWrapper { } // namespace -std::unique_ptr> createLhloFuseLinalg( +std::unique_ptr createLhloFuseLinalgPass( bool use_parallel_loops, ArrayRef tile_sizes) { - return absl::make_unique(use_parallel_loops, tile_sizes); + return std::make_unique(use_parallel_loops, tile_sizes); } -static PassRegistration legalize_pass( - "lhlo-fuse-linalg", - "Greedily fuse linalg ops obtained after LHLO lowering."); - } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index a353472be4b..07891327775 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -15,17 +15,16 @@ limitations under the License. // This file implements logic for lowering LHLO dialect to Affine dialect. 
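With the static registration gone, the fusion pass is built explicitly, and its options (use_parallel_loops, tile_sizes) are passed to the factory. A hedged usage sketch; the element type of tile_sizes and the exact return type of createLhloFuseLinalgPass are assumptions, since the template arguments are not visible in the hunk above, and the pipeline function name is made up:

    // Sketch only: explicit construction of the fusion pass in a pipeline.
    #include "llvm/ADT/SmallVector.h"
    #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
    #include "mlir/IR/Function.h"
    #include "mlir/Pass/PassManager.h"

    void BuildLhloFusionPipeline(mlir::PassManager &pm) {
      // Illustrative tile sizes; real values depend on the target.
      llvm::SmallVector<unsigned, 2> tile_sizes = {2, 2};
      pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloFuseLinalgPass(
          /*use_parallel_loops=*/true, tile_sizes));
    }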
-#include "absl/memory/memory.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -138,8 +137,8 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, // clang-format on } -struct LhloLegalizeToAffine - : public PassWrapper { +struct LhloLegalizeToAffinePass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; auto func = getFunction(); @@ -150,12 +149,9 @@ struct LhloLegalizeToAffine } // namespace -std::unique_ptr> createLegalizeToAffinePass() { - return absl::make_unique(); +std::unique_ptr> createLhloLegalizeToAffinePass() { + return std::make_unique(); } -static PassRegistration legalize_pass( - "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); - } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 0ff491a93c3..cffb58b37de 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -17,25 +17,24 @@ limitations under the License. 
#include -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { @@ -148,9 +147,9 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Now copy over the actual body of the reduction, leaving out the // terminator. 
BlockAndValueMapping mapping; - mapping.map(reduce_op.body().front().getArgument(0), accumulator); - mapping.map(reduce_op.body().front().getArgument(1), rhs); - mapping.map(reduce_op.body().front().getArgument(2), accumulator); + mapping.map(reduce_op.body().getArgument(0), accumulator); + mapping.map(reduce_op.body().getArgument(1), rhs); + mapping.map(reduce_op.body().getArgument(2), accumulator); for (auto& nested : reduce_op.body().front().without_terminator()) { auto clone = rewriter.clone(nested, mapping); for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) { @@ -168,7 +167,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { }; }; -struct LhloLegalizeToGpu : public PassWrapper { +struct LhloLegalizeToGpuPass + : public PassWrapper { void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); @@ -185,12 +185,9 @@ struct LhloLegalizeToGpu : public PassWrapper { } // namespace -std::unique_ptr> createLegalizeToGpuPass() { - return absl::make_unique(); +std::unique_ptr createLegalizeToGpuPass() { + return std::make_unique(); } -static PassRegistration legalize_pass( - "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); - } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 32606f068a8..42b71543543 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
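The reduction lowering above clones the body of the reduce op into the GPU launch region: each block argument of the body is remapped (onto the accumulator and rhs values) and every non-terminator op is cloned through the mapping. A reduced sketch of that idiom, with an illustrative helper name:

    // Sketch only: clone a single-block region body with remapped arguments.
    #include "llvm/ADT/STLExtras.h"
    #include "mlir/IR/BlockAndValueMapping.h"
    #include "mlir/IR/Builders.h"

    static void CloneBodyWithRemappedArgs(mlir::Region &body,
                                          mlir::ValueRange replacements,
                                          mlir::OpBuilder &builder) {
      mlir::BlockAndValueMapping mapping;
      for (auto pair : llvm::zip(body.front().getArguments(), replacements))
        mapping.map(std::get<0>(pair), std::get<1>(pair));
      // Clone everything except the terminator; cloned ops see the remapped
      // values instead of the original block arguments.
      for (mlir::Operation &nested : body.front().without_terminator())
        builder.clone(nested, mapping);
    }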
==============================================================================*/ -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { @@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter Location loc = op->getLoc(); auto reshape_op = cast(op); - Type dst_type = reshape_op.getResult().getType(); - auto element_type = dst_type.cast().getElementType(); + auto dst_type = reshape_op.getResult().getType().cast(); + auto element_type = dst_type.getElementType(); auto shape = reshape_op.shape(); @@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr); desc.setOffset(rewriter, loc, ptrs_n_offset.offset); - auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType()) - .cast(); - auto llvmIndexTyPtr = llvmIndexTy.getPointerTo(); + auto llvm_index_type = typeConverter.getIndexType(); + auto llvm_index_ptr_type = llvm_index_type.getPointerTo(); Value stride_carried = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); for (int i = shape_length - 1; i >= 0; --i) { Value pos = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); Value ptr = rewriter.create( - loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc), + loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc), ValueRange{pos}); Value extracted_size = rewriter.create(loc, ptr); desc.setSize(rewriter, loc, i, extracted_size); @@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter rewriter.replaceOp(op, {desc}); } else { Value rank = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length)); Value alloca = typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter); @@ -199,15 +198,126 @@ struct ReshapeMemRefCastOpConverter {rank, void_ptr}); rewriter.replaceOp(op, {unranked_desc}); } - } else { - /* - * TODO(pifon, herhut): - * Compute strides with llvm.loop; - * Use UnrankedMemrefDescr::ComputeSize with Alloca; - * Set all the fields using getelementptr. - */ - return failure(); + return success(); } + + // The shape is a rank-1 tensor with unknown length. + Value result_rank = shape_desc.size(rewriter, loc, 0); + // TODO(herhut): Propely handle address spaces. + unsigned address_space = 0; + auto target_type = + typeConverter + .convertType(UnrankedMemRefType::get(element_type, address_space)) + .cast(); + // Create the unranked memref descriptor that holds the ranked one. The + // inner descriptor is allocated on stack. 
+ UnrankedMemRefDescriptor target_desc = + UnrankedMemRefDescriptor::undef(rewriter, loc, target_type); + target_desc.setRank(rewriter, loc, result_rank); + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + {target_desc}, sizes); + auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext()); + Value ranked_desc_mem = rewriter.create( + loc, void_ptr_type, sizes.front(), llvm::None); + target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem); + + // Fill the fixed parts. For this, we cast to a 0-D memref. + auto zero_d_memref_type = MemRefType::get({}, element_type); + Value as_zero_d = rewriter.create( + loc, + typeConverter.convertType(zero_d_memref_type) + .cast() + .getPointerTo(address_space), + ranked_desc_mem); + // Some common constants. Use 32 bit where required by gep struct indexes. + auto int32_type = typeConverter.convertType(rewriter.getI32Type()); + Value zero_index = rewriter.create( + loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0)); + Value zero = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(0)); + Value one = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(1)); + Value two = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(2)); + // Set base_pointer and aligned pointer. + auto element_ptr_ptr_type = typeConverter.convertType(element_type) + .cast() + .getPointerTo(address_space) + .getPointerTo(address_space); + auto base_gep = rewriter.create( + loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero})); + rewriter.create(loc, ptrs_n_offset.allocated_ptr, base_gep); + auto aligned_gep = rewriter.create( + loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one})); + rewriter.create(loc, ptrs_n_offset.aligned_ptr, aligned_gep); + // Set offset. + auto index_ptr_type = + typeConverter.getIndexType().getPointerTo(address_space); + auto offset_gep = rewriter.create( + loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two})); + rewriter.create(loc, ptrs_n_offset.offset, offset_gep); + + // Use the offset pointer as base for further addressing. Copy over the + // new shape and compute strides. For this, we need to create a loop from + // rank - 1 to 0. + Value one_index = rewriter.create( + loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1)); + auto target_shape_base = rewriter.create( + loc, index_ptr_type, offset_gep, ValueRange({one})); + auto target_strides_base = rewriter.create( + loc, index_ptr_type, target_shape_base, ValueRange({result_rank})); + auto shape_ptr = shape_desc.alignedPtr(rewriter, loc); + auto result_rank_minus_one = + rewriter.create(loc, result_rank, one_index); + + Block *init_block = rewriter.getInsertionBlock(); + Block *cond_block = + rewriter.splitBlock(init_block, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(init_block); + rewriter.create( + loc, ValueRange({result_rank_minus_one, one_index}), cond_block); + rewriter.setInsertionPointToStart(cond_block); + auto index_arg = cond_block->addArgument(typeConverter.getIndexType()); + auto stride_arg = cond_block->addArgument(typeConverter.getIndexType()); + auto pred = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), + LLVM::ICmpPredicate::sge, index_arg, zero_index); + + Block *body_block = + rewriter.splitBlock(cond_block, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(body_block); + + // Copy size from shape to descriptor. 
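The GEP struct indices 0, 1 and 2 used above address the leading fields of MLIR's standard ranked memref descriptor after the cast to a 0-D memref. For orientation, the descriptor layout corresponds roughly to the following struct (a sketch of the convention, not code from this patch); the per-dimension sizes and strides follow as two trailing arrays of length rank:

    // Sketch of the ranked memref descriptor layout targeted by the GEPs above.
    #include <cstdint>

    template <typename T>
    struct RankedMemRefDescriptorHead {
      T *allocated_ptr;  // field 0: pointer returned by the allocator
      T *aligned_ptr;    // field 1: aligned pointer used for element access
      int64_t offset;    // field 2: element offset into aligned_ptr
      // int64_t sizes[rank];   // written via target_shape_base below
      // int64_t strides[rank]; // written via target_strides_base below
    };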
+ auto size_load_gep = rewriter.create( + loc, index_ptr_type, shape_ptr, ValueRange{index_arg}); + auto extracted_size = rewriter.create(loc, size_load_gep); + auto size_store_gep = rewriter.create( + loc, index_ptr_type, target_shape_base, ValueRange({index_arg})); + rewriter.create(loc, extracted_size, size_store_gep); + // Write stride value and compute next one. + auto stride_store_gep = rewriter.create( + loc, index_ptr_type, target_strides_base, ValueRange({index_arg})); + rewriter.create(loc, stride_arg, stride_store_gep); + auto next_stride = + rewriter.create(loc, stride_arg, extracted_size); + + // Decrement loop counter and branch back. + auto decrement = rewriter.create(loc, index_arg, one_index); + rewriter.create(loc, ValueRange({decrement, next_stride}), + cond_block); + + Block *remainder = + rewriter.splitBlock(body_block, rewriter.getInsertionPoint()); + + // Hook up the cond exit to the remainder. + rewriter.setInsertionPointToEnd(cond_block); + rewriter.create(loc, pred, body_block, ValueRange(), + remainder, ValueRange()); + + // Reset position to beginning of new remainder block. + rewriter.setInsertionPointToStart(remainder); + rewriter.replaceOp(op, {target_desc}); return success(); } @@ -250,11 +360,10 @@ struct ReshapeMemRefCastOpConverter } // namespace -void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, - LLVMTypeConverter *converter, +void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert(*converter, options); + StaticMemRefCastOpConverter>(*converter); } } // namespace lmhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index d6cda99a912..8493a1feb5d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -13,16 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
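The branch-based loop built above walks the new shape from the innermost dimension outwards, copying each extent into the descriptor and carrying a running product as the stride. In plain C++, the computation the emitted LLVM-dialect loop performs is the usual row-major stride calculation (reference sketch only):

    // Reference sketch of the stride computation encoded by the loop above.
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> RowMajorStrides(const std::vector<int64_t> &shape) {
      std::vector<int64_t> strides(shape.size());
      int64_t carried = 1;  // mirrors the `stride_arg` block argument
      for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
        strides[i] = carried;
        carried *= shape[i];  // mirrors `next_stride = stride_arg * size`
      }
      return strides;
    }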
==============================================================================*/ -#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" namespace mlir { namespace lmhlo { @@ -36,13 +34,9 @@ class TestLhloToLLVMPass ModuleOp m = getOperation(); OwningRewritePatternList patterns; - LLVMTypeConverter converter(m.getContext()); + LLVMTypeConverter converter(&getContext()); populateStdToLLVMConversionPatterns(converter, patterns); - PopulateLhloToLLVMConversionPatterns( - LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns); - mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); - - mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); + PopulateLhloToLLVMConversionPatterns(&converter, &patterns); ConversionTarget target(getContext()); target.addLegalDialect(); @@ -57,8 +51,9 @@ class TestLhloToLLVMPass } // namespace -static PassRegistration legalize_lhlo_pass( - "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); +std::unique_ptr createTestLhloToLLVMPass() { + return std::make_unique(); +} } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 4255d87d48e..19f47d08c0d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -13,17 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
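TestLhloToLLVMPass now builds its LLVMTypeConverter directly from the pass context and populates only the std-to-LLVM and LHLO-to-LLVM patterns, without the affine and SCF lowerings it previously pulled in. The overall driver follows the usual dialect-conversion recipe; a hedged sketch of that structure as a free function (the function name is made up, includes assumed from the hunks above):

    // Sketch only: mirrors the structure of the test pass shown above.
    #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
    #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
    #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
    #include "mlir/IR/Module.h"
    #include "mlir/Transforms/DialectConversion.h"

    mlir::LogicalResult RunLhloToLLVM(mlir::ModuleOp module) {
      mlir::MLIRContext *ctx = module.getContext();
      mlir::LLVMTypeConverter converter(ctx);

      mlir::OwningRewritePatternList patterns;
      mlir::populateStdToLLVMConversionPatterns(converter, patterns);
      mlir::lmhlo::PopulateLhloToLLVMConversionPatterns(&converter, &patterns);

      mlir::ConversionTarget target(*ctx);
      target.addLegalDialect<mlir::LLVM::LLVMDialect>();
      target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
      return mlir::applyFullConversion(module, target, patterns);
    }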
==============================================================================*/ -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace lmhlo { @@ -690,8 +689,8 @@ class SelectAndScatterOpConverter } }; -struct LhloLegalizeToParallelLoops - : public PassWrapper { +struct LhloLegalizeToParallelLoopsPass + : public PassWrapper { void runOnFunction() override { auto func = getFunction(); @@ -715,16 +714,11 @@ struct LhloLegalizeToParallelLoops } } }; - } // namespace std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { - return absl::make_unique(); + return std::make_unique(); } -static PassRegistration legalize_lhlo_pass( - "lhlo-legalize-to-parallel-loops", - "Legalize from LHLO dialect to parallel loops."); - } // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc index 54ea4955573..9f7c946577d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -23,17 +23,17 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" using mlir::FunctionPass; using mlir::OwningRewritePatternList; @@ -41,9 +41,9 @@ using mlir::PassRegistration; using mlir::PassWrapper; namespace { -class LowerComplex : public PassWrapper { +class LowerComplexPass : public PassWrapper { public: - explicit LowerComplex() : PassWrapper() {} + explicit LowerComplexPass() : PassWrapper() {} /// Performs the lowering to MHLO dialect. 
void runOnFunction() override; @@ -51,10 +51,10 @@ class LowerComplex : public PassWrapper { } // end anonymous namespace namespace mlir { -namespace hlo { +namespace mhlo { namespace { -#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" +#include "generated_lower_complex.inc" } // end anonymous namespace @@ -62,18 +62,18 @@ void PopulateComplexLoweringPatterns(MLIRContext* context, OwningRewritePatternList* patterns) { populateWithGenerated(context, patterns); } -} // end namespace hlo +} // end namespace mhlo } // end namespace mlir // Lowers the complex operations that can be represented using other operations. -void LowerComplex::runOnFunction() { +void LowerComplexPass::runOnFunction() { // Add lowering patterns to the list. OwningRewritePatternList patterns; - mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); + mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns); applyPatternsAndFoldGreedily(getFunction(), patterns); } -static PassRegistration pass( - "mhlo-test-lower-complex", - "Lower complex operations into non-complex operations"); +std::unique_ptr mlir::mhlo::createLowerComplexPass() { + return std::make_unique(); +} diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td index 0b72ccaa823..2cc97c90d1c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -18,7 +18,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/Dialect/StandardOps/IR/Ops.td" -include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" //===----------------------------------------------------------------------===// // Binary op patterns. @@ -89,12 +89,10 @@ def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), // Absolute value is evaluated as: // result = sqrt(val.real * val.real + val.imag * val.imag) def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), - (HLO_ComplexOp (HLO_SqrtOp (HLO_AddOp (HLO_MulOp (HLO_RealOp:$real $val), $real), - (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), - (HLO_ConstOp (ConstantSplat<"0"> $real)))>; + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag)))>; // Exponential can be lowered to an exponential on the real component and a // sum of sinusoids of the imaginary component, which equates to a normal diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc index 32a6ce42e5e..2bbd4691f95 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -17,18 +17,18 @@ limitations under the License. 
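The comment above the AbsOp pattern spells out the lowering it implements for complex inputs: result = sqrt(val.real * val.real + val.imag * val.imag), with the magnitude produced as a real value rather than wrapped back into a complex one. As a scalar reference for the same identity:

    // Scalar reference for the complex-abs lowering described above.
    #include <cmath>

    float ComplexAbs(float re, float im) {
      return std::sqrt(re * re + im * im);
    }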
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" using mlir::DenseIntElementsAttr; using mlir::ElementsAttr; @@ -170,8 +170,8 @@ struct GeneralDotConvert : public OpRewritePattern { } }; -struct LegalizeGeneralDot - : public PassWrapper { +struct LegalizeGeneralDotPass + : public PassWrapper { /// Lower all general dots that can be represented as a non-batched matmul. void runOnFunction() override { OwningRewritePatternList patterns; @@ -187,6 +187,6 @@ void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns( patterns->insert(ctx); } -static PassRegistration legalize_pass( - "mhlo-test-lower-general-dot", - "Tests lowering general dot to a non-batched dot when possible"); +std::unique_ptr<::mlir::Pass> mlir::mhlo::createLegalizeGeneralDotPass() { + return std::make_unique(); +} diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc index c2f88ad5e31..445cf2e79fe 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index 1d5d593bd43..3909f046007 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass } // namespace +std::unique_ptr<::mlir::Pass> createTestMaterializeBroadcastsPass() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir - -static mlir::PassRegistration pass( - "mhlo-test-materialize-broadcasts", - "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc index 91f9344b8c5..233d95a1a65 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/mhlo_fusion.cc @@ -18,14 +18,14 @@ limitations under the License. 
#include #include +#include "llvm/ADT/EquivalenceClasses.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/utils/cycle_detector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project #include "mlir/IR/MLIRContext.h" // TF:llvm-project #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" // TF:local_config_mlir #include "mlir/Transforms/RegionUtils.h" // TF:llvm-project -#include "llvm/ADT/EquivalenceClasses.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" // This pass has similar functionality of the fusion pass in XLA stack. // However, unlike XLA, it targets the fully dynamic shape scenario. @@ -479,7 +479,7 @@ class FusionPlanner { EquivalenceClasses leader_for_node_; }; -struct MhloFusion : public mlir::PassWrapper { +struct MhloFusionPass : public mlir::PassWrapper { void runOnFunction() override { FuncOp func = getFunction(); if (!IsTargetFunc(func)) { @@ -568,12 +568,9 @@ struct MhloFusion : public mlir::PassWrapper { } // namespace -std::unique_ptr> createMhloFusion() { - return std::make_unique(); +std::unique_ptr> createMhloFusionPass() { + return std::make_unique(); } -static PassRegistration mhlo_fusion_pass( - "mhlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns."); - } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc new file mode 100644 index 00000000000..43de47086bf --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo.cc @@ -0,0 +1,187 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file provides optional optimization patterns for mhlo, canonocalizing +// operations to equivalent but potentially more efficient operations. + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/utils/hlo_utils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +using mlir::OwningRewritePatternList; + +namespace mlir { +namespace mhlo { +namespace { + +// Returns 1D 64-bit dense elements attribute with the given values. 
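The fusion planner in MhloFusionPass groups ops with llvm::EquivalenceClasses (a union-find structure) and a cycle detector that keeps the fused graph acyclic. A reduced illustration of the union-find part, with made-up node ids:

    // Sketch only: demonstrates the llvm::EquivalenceClasses API used by the
    // fusion planner; node ids are made up.
    #include <cassert>
    #include "llvm/ADT/EquivalenceClasses.h"

    void DemoFusionGroups() {
      llvm::EquivalenceClasses<int> groups;
      groups.unionSets(1, 2);  // ops 1 and 2 end up in one fusion group
      groups.unionSets(2, 3);  // op 3 joins the same group transitively
      groups.insert(4);        // op 4 is known but stays unfused
      assert(groups.isEquivalent(1, 3));
      assert(!groups.isEquivalent(1, 4));
    }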
+static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +//===----------------------------------------------------------------------===// +// GatherOp +//===----------------------------------------------------------------------===// + +class GatherIsSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter& rewriter) const override { + auto dimension_numbers = gather.dimension_numbers(); + + // Inputs need to be ranked to lower. + if (!gather.operand().getType().cast().hasRank() || + !gather.operand().getType().cast().hasStaticShape() || + !gather.start_indices().getType().cast().hasRank() || + !gather.start_indices().getType().cast().hasStaticShape()) { + return failure(); + } + + if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) { + return failure(); + } + + // TODO(suderman): Handle start index map != {0}. + if (!dimension_numbers.start_index_map() || + dimension_numbers.start_index_map().getType().getRank() != 1 || + dimension_numbers.start_index_map().getType().getDimSize(0) != 1 || + dimension_numbers.start_index_map() + .getValue({0}) + .cast() + .getValue() != 0) { + return failure(); + } + + auto result_ty = gather.getResult().getType().dyn_cast(); + + // Requires a ranked output. + if (!result_ty) { + return failure(); + } + if (dimension_numbers.offset_dims().getType().getNumElements() != + result_ty.getRank()) { + return failure(); + } + for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) { + if (it.index() != it.value()) { + return failure(); + } + } + + // Verify the gather slice sizes are correct. + if (gather.slice_sizes().getNumElements() != + gather.operand().getType().cast().getRank()) { + return failure(); + } + + // Validate the slice sizes are correct. + if (gather.slice_sizes().getType().cast().getNumElements() < + result_ty.getShape().size() + 1) { + return failure(); + } + + for (auto it : llvm::enumerate(result_ty.getShape())) { + if (gather.slice_sizes() + .getValue(it.index() + 1) + .cast() + .getValue() != it.value()) { + return failure(); + } + } + + auto gather_start_indices = gather.start_indices(); + auto gather_start_indices_ty = + gather_start_indices.getType().cast(); + + llvm::SmallVector slice_start_indices; + + if (gather_start_indices_ty.getRank() == 0) { + slice_start_indices.push_back(gather_start_indices); + } else if (gather_start_indices_ty.getRank() == 1) { + for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) { + auto start = GetI64ElementsAttr({i}, &rewriter); + auto limit = GetI64ElementsAttr({i + 1}, &rewriter); + auto stride = GetI64ElementsAttr({1}, &rewriter); + auto indicesSlice = rewriter.create( + gather.getLoc(), gather_start_indices, start, limit, stride); + auto reshaped = rewriter.create( + gather.getLoc(), + RankedTensorType::get( + {}, indicesSlice.getType().cast().getElementType()), + indicesSlice); + slice_start_indices.push_back(reshaped); + } + } else { + return failure(); + } + + auto sliceSizes = gather.slice_sizes(); + auto sliceSizesTy = sliceSizes.getType(); + if (sliceSizesTy.getRank() != 1) { + return failure(); + } + + // Start indices have implicit zeros when not specified. This is because + // Gather occurs similar to slicing where full slices are inferred. 
Add any + // missing zeros as necessary. + auto zero = rewriter.create( + gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get( + {}, gather_start_indices_ty.getElementType()))); + while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) { + slice_start_indices.push_back(zero); + } + + SmallVector sliceShape; + for (auto shapeValue : gather.slice_sizes().getIntValues()) { + sliceShape.push_back(shapeValue.getSExtValue()); + } + + auto sliceTy = + RankedTensorType::get(sliceShape, result_ty.getElementType()); + auto slice = rewriter.create( + gather.getLoc(), sliceTy, gather.operand(), slice_start_indices, + gather.slice_sizes()); + + rewriter.replaceOpWithNewOp(gather, gather.getType(), slice); + + return success(); + } +}; + +} // end anonymous namespace + +void PopulateOptimizeMHLOPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + patterns->insert(context); +} +} // end namespace mhlo +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc new file mode 100644 index 00000000000..32a846e79ef --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using mlir::FunctionPass; +using mlir::PassRegistration; +using mlir::PassWrapper; + +namespace { +class OptimizeMhloPass : public PassWrapper { + public: + explicit OptimizeMhloPass() : PassWrapper() {} + + /// Performs the lowering to MHLO dialect. + void runOnFunction() override; +}; +} // end anonymous namespace + +// Lowers the complex operations that can be represented using other operations. +void OptimizeMhloPass::runOnFunction() { + // Add lowering patterns to the list. 
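PopulateOptimizeMHLOPatterns exposes the GatherIsSlice rewrite as a reusable pattern set, so it can also be applied outside OptimizeMhloPass. A hedged sketch of such reuse; it assumes the populate function is declared in one of the mlir-hlo headers included above, and ApplyMhloOptimizations is a made-up name:

    // Sketch only: applying the optimization patterns from another driver.
    #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"  // assumed header
    #include "mlir/IR/Function.h"
    #include "mlir/IR/PatternMatch.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    void ApplyMhloOptimizations(mlir::FuncOp func) {
      mlir::OwningRewritePatternList patterns;
      mlir::mhlo::PopulateOptimizeMHLOPatterns(func.getContext(), &patterns);
      mlir::applyPatternsAndFoldGreedily(func, patterns);
    }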
+ mlir::OwningRewritePatternList patterns; + mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns); + + applyPatternsAndFoldGreedily(getFunction(), patterns); +} + +std::unique_ptr mlir::mhlo::createOptimizeMhloPass() { + return std::make_unique(); +} diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc index b05918030e9..8d677f45c19 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -15,12 +15,13 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Casting.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir { namespace mhlo { @@ -29,8 +30,16 @@ namespace { // A pass that sinks constants implicitly captured in control flow regions. This // is necessary to export to XLA. -class SinkConstantsToControlFlow - : public mlir::PassWrapper { +// +// TODO(hinsu): Generalize this pass to handle all the ops with regions. Any +// value used within the region that is defined outside of op's region should be +// sank to the regions and not just the constants. Ops such as If and While +// whose computations doesn't require fixed signature like Sort or Reduce have +// an option to pass outside values as operands of the op to avoid recomputing +// those within internally. Note that doing so is the only option in case of +// values defined outside that are BlockArguments of any of the parent region. +class SinkConstantsToControlFlowPass + : public mlir::PassWrapper { void runOnFunction() override { getFunction().walk([](Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { @@ -39,6 +48,8 @@ class SinkConstantsToControlFlow } else if (auto if_op = llvm::dyn_cast(op)) { SinkToRegion(&if_op.true_branch()); SinkToRegion(&if_op.false_branch()); + } else if (auto sort_op = llvm::dyn_cast(op)) { + SinkToRegion(&sort_op.comparator()); } }); } @@ -46,39 +57,36 @@ class SinkConstantsToControlFlow private: // Performs constant sinking into a region. static void SinkToRegion(Region* region) { - llvm::DenseMap sunk_constant; + llvm::DenseMap sunk_constant; visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { Value constant = use->get(); - auto const_op = dyn_cast_or_null(constant.getDefiningOp()); - if (!const_op) return; + auto op = constant.getDefiningOp(); + if (!op || !op->hasTrait()) return; auto map_entry = sunk_constant.try_emplace(constant, nullptr); if (!map_entry.second) { // This constant has already been cloned into the region, reuse it. 
- use->set(map_entry.first->getSecond().getResult()); - if (constant.use_empty()) const_op.erase(); + use->set(map_entry.first->getSecond()->getResult(0)); + if (op->use_empty()) op->erase(); return; } if (constant.hasOneUse()) { - const_op.getOperation()->moveBefore(®ion->front().front()); + op->moveBefore(®ion->front().front()); return; } - map_entry.first->getSecond() = const_op.clone(); + map_entry.first->getSecond() = op->clone(); region->front().getOperations().insert(region->front().begin(), map_entry.first->getSecond()); - use->set(map_entry.first->getSecond().getResult()); + use->set(map_entry.first->getSecond()->getResult(0)); }); } }; -static mlir::PassRegistration pass( - "mhlo-sink-constants-to-control-flow", - "Sink constants implicitly captured in control flow regions. This is " - "necessary to export to XLA."); - } // anonymous namespace +// TODO(hinsu): Rename this pass and move to a different file along with the +// generalization to make all ops isolated from above. std::unique_ptr> createSinkConstantsToControlFlowPass() { - return std::make_unique(); + return std::make_unique(); } } // namespace mhlo diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc index 184420bb8f7..35e5a184472 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Identifier.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" namespace mlir { -namespace hlo { +namespace mhlo { namespace { struct InferReturnTypeComponentsPattern : public RewritePattern { @@ -92,9 +92,10 @@ struct TestInferShapedTypeMethodsPass }; } // namespace -} // namespace hlo -} // namespace mlir -static mlir::PassRegistration pass( - "mhlo-test-infer-shaped-type-methods", - "Uses test ops to invoke InferShapedTypeOpInterface methods"); +std::unique_ptr createTestInferShapedTypeMethodsPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc index 53947855cc7..7c985ea7535 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/transform_unranked_hlo.cc @@ -14,24 +14,38 @@ limitations under the License. 
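The sinking logic above now accepts any constant-producing op rather than only std ConstantOp: it keys on a trait of the defining op. The trait's template argument is not visible in the hunk, so the spelling below is an assumption; a minimal sketch of the check:

    // Sketch only: assumes the trait used is mlir::OpTrait::ConstantLike.
    #include "mlir/IR/OpDefinition.h"
    #include "mlir/IR/Value.h"

    static bool IsSinkableConstant(mlir::Value value) {
      mlir::Operation *def = value.getDefiningOp();
      return def && def->hasTrait<mlir::OpTrait::ConstantLike>();
    }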
==============================================================================*/ -#include "absl/memory/memory.h" -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { namespace { -// TODO(frgossen): Make it variadic. +// TODO(herhut): Generate these out of op definitions. +#define MAP_XLA_OPERATION_CWISE_UNARY(fn, sep) \ + fn(AbsOp) sep fn(CeilOp) sep fn(ClzOp) sep fn(CosOp) sep fn(ExpOp) \ + sep fn(Expm1Op) sep fn(FloorOp) sep fn(ImagOp) sep fn(IsFiniteOp) \ + sep fn(LogOp) sep fn(Log1pOp) sep fn(LogisticOp) sep fn(NotOp) \ + sep fn(NegOp) sep fn(PopulationCountOp) sep fn(RealOp) \ + sep fn(RoundOp) sep fn(RsqrtOp) sep fn(SignOp) sep fn(SinOp) \ + sep fn(SqrtOp) sep fn(TanhOp) + +// TODO(herhut): Generate these out of op definitions. +#define MAP_XLA_OPERATION_CWISE_BINARY(fn, sep) \ + fn(AddOp) sep fn(Atan2Op) sep fn(ComplexOp) sep fn(DivOp) sep fn(MaxOp) \ + sep fn(MinOp) sep fn(MulOp) sep fn(PowOp) sep fn(RemOp) \ + sep fn(ShiftLeftOp) sep fn(ShiftRightArithmeticOp) \ + sep fn(ShiftRightLogicalOp) sep fn(SubOp) + template inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { target->addDynamicallyLegalOp([](OpTy op) { @@ -60,29 +74,24 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern { // Generate IR to flatten the operand. auto loc = op.getLoc(); - Value shape = rewriter.create(loc, operand); - Value numElements = rewriter.create( - loc, rewriter.getType(), shape); - Value numElementsAsIndex = rewriter.create( - loc, rewriter.getIndexType(), numElements); - Value flatShapeAsDimTensor = - rewriter.create(loc, numElementsAsIndex); + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shape = + rewriter.create(loc, extentTensorTy, operand); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, operandTy.getElementType()); Value flatOperand = rewriter.create( - loc, flatTensorTy, operand, flatShapeAsDimTensor); + loc, flatTensorTy, operand, flatShape); // Generate IR for the actual operation. Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); // Generate IR to restore the original shape. 
- auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, operandTy, flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, operandTy, + flatResult, shape); return success(); } @@ -108,17 +117,18 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { } // Flatten operands. - Type shapeTy = shape::ShapeType::get(rewriter.getContext()); auto loc = op.getLoc(); - Value shapeLhs = rewriter.create(loc, op.lhs()); - Value shapeRhs = rewriter.create(loc, op.rhs()); - Value shape = rewriter.create(loc, shapeTy, + Type extentTensorTy = shape::getExtentTensorType(rewriter.getContext()); + Value shapeLhs = + rewriter.create(loc, extentTensorTy, op.lhs()); + Value shapeRhs = + rewriter.create(loc, extentTensorTy, op.rhs()); + Value shape = rewriter.create(loc, extentTensorTy, ValueRange{shapeLhs, shapeRhs}); - Value numElements = rewriter.create(loc, shape); - Value numElementsAsIndex = - rewriter.create(loc, numElements); - Value flatShape = - rewriter.create(loc, numElementsAsIndex); + Type indexTy = rewriter.getIndexType(); + Value numElements = + rewriter.create(loc, indexTy, shape); + Value flatShape = rewriter.create(loc, numElements); TensorType lhsTy = op.lhs().getType().template cast(); Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, lhsTy.getElementType()); @@ -134,13 +144,8 @@ struct BinaryElementwiseOpConversion : public OpRewritePattern { Value flatResult = rewriter.create(loc, flatLhs, flatRhs); // Restore original shape. - auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, - rewriter.getIndexType()); - Value shapeAsExtentTensor = - rewriter.create(loc, extentTensorTy, shape); - Value result = rewriter.create( - loc, op.getType(), flatResult, shapeAsExtentTensor); - rewriter.replaceOp(op, result); + rewriter.replaceOpWithNewOp(op, op.getType(), flatResult, + shape); return success(); } @@ -155,15 +160,17 @@ struct TransformUnrankedHloPass target.addLegalDialect(); target.addLegalOp(); - AddLegalOpOnRankedTensor(&target); - AddLegalOpOnRankedTensor(&target); +#define ADD_LEGAL(op) AddLegalOpOnRankedTensor(&target) + MAP_XLA_OPERATION_CWISE_UNARY(ADD_LEGAL, ;); + MAP_XLA_OPERATION_CWISE_BINARY(ADD_LEGAL, ;); +#undef ADD_LEGAL // Populate rewrite patterns. OwningRewritePatternList patterns; PopulateTransformUnrankedHloPatterns(&ctx, &patterns); // Apply transformation. - if (failed(applyFullConversion(getFunction(), target, patterns))) + if (failed(applyPartialConversion(getFunction(), target, patterns))) return signalPassFailure(); } }; @@ -174,15 +181,22 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { // TODO(frgossen): Populate all unary and binary operations. 
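Both elementwise conversions above follow the same flatten-apply-restore recipe: capture the operand shape as an extent tensor, reshape to a rank-1 tensor of dynamic size, run the op on the flat tensor, and reshape the result back through the saved shape. The recipe works because an elementwise op depends only on the element count, not on the layout; as a plain-C++ analogy (illustrative only):

    // Plain-C++ analogy of flatten-apply-restore for a unary elementwise op.
    #include <cstddef>
    #include <vector>

    std::vector<float> ApplyElementwiseFlat(const std::vector<float> &flat_data,
                                            float (*op)(float)) {
      // The shape is carried separately by the caller and re-attached to the
      // result; the computation itself only sees the flat buffer.
      std::vector<float> out(flat_data.size());
      for (std::size_t i = 0; i < flat_data.size(); ++i)
        out[i] = op(flat_data[i]);
      return out;
    }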
// clang-format off +#define MAP_UNARY(op) UnaryElementwiseOpConversion +#define MAP_BINARY(op) BinaryElementwiseOpConversion +#define COMMA , patterns->insert< - BinaryElementwiseOpConversion, - UnaryElementwiseOpConversion>(context); + MAP_XLA_OPERATION_CWISE_UNARY(MAP_UNARY, COMMA), + MAP_XLA_OPERATION_CWISE_BINARY(MAP_BINARY, COMMA) + >(context); +#undef MAP_UNARY +#undef MAP_BINARY +#undef COMMA // clang-format on } -static PassRegistration transform_unranked_hlo_pass( - "transform-unranked-hlo", - "Realize element-wise operations on ranked tensors where possible"); +std::unique_ptr<::mlir::Pass> createTransformUnrankedHloPass() { + return std::make_unique(); +} } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc index 09c9c61119e..1458e5f3d63 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc index c26d73f3306..f187a7470cf 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. 
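The MAP_XLA_OPERATION_CWISE_* macros above are X-macros: a single op list takes a per-op macro and a separator, so the same list expands once into per-op legality statements (with ';' as the separator) and once into a comma-separated template argument list for patterns->insert<...>. A reduced, self-contained illustration with stand-in op types:

    // Sketch only: AbsOp/CeilOp/SqrtOp are stand-in structs, not the real ops.
    #include <iostream>

    struct AbsOp {};
    struct CeilOp {};
    struct SqrtOp {};

    #define MY_CWISE_UNARY(fn, sep) fn(AbsOp) sep fn(CeilOp) sep fn(SqrtOp)

    template <typename OpTy>
    void MarkLegalOnRankedTensors() { /* ConversionTarget setup goes here */ }

    template <typename... Patterns>
    void InsertPatterns() {
      std::cout << sizeof...(Patterns) << " patterns inserted\n";
    }

    int main() {
      // Expansion 1: one statement per op, ';' as the separator.
    #define ADD_LEGAL(op) MarkLegalOnRankedTensors<op>()
      MY_CWISE_UNARY(ADD_LEGAL, ;);
    #undef ADD_LEGAL

      // Expansion 2: a comma-separated template argument list.
    #define MAP_UNARY(op) op
    #define COMMA ,
      InsertPatterns<MY_CWISE_UNARY(MAP_UNARY, COMMA)>();
    #undef MAP_UNARY
    #undef COMMA
      return 0;
    }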
==============================================================================*/ -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace mhlo { @@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass } // namespace +std::unique_ptr<::mlir::Pass> createTestUnfuseBatchNormPass() { + return std::make_unique(); +} + } // namespace mhlo } // namespace mlir - -static mlir::PassRegistration pass( - "mhlo-test-unfuse-batch-norm", - "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh b/tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt old mode 100755 new mode 100644 similarity index 58% rename from tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh rename to tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt index b1cae48c6ee..17e86f1caa8 --- a/tensorflow/tools/dockerfiles/tests/import-mkl-horovod.sh +++ b/tensorflow/compiler/mlir/hlo/lib/utils/CMakeLists.txt @@ -1,18 +1,25 @@ -#!/usr/bin/env bash - -# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# ============================================================================ +# -python -c 'from tensorflow.python import pywrap_tensorflow; pywrap_tensorflow.IsMklEnabled() or exit(1); import horovod.tensorflow as hvd' +add_mlir_library(MLIRMhloUtils + broadcast_utils.cc + convert_op_folder.cc + cycle_detector.cc + hlo_utils.cc + + LINK_LIBS PUBLIC + MLIRSupport + ) diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc index e05ec3c3481..71b1a4e164f 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/broadcast_utils.cc @@ -13,15 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" +#include "mlir-hlo/utils/broadcast_utils.h" #include #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { namespace hlo { @@ -46,9 +47,9 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, broadcast_dims.getIntValues().begin()); } -Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, - Value rhs, - OpBuilder& builder) { +Value ComputeBinaryElementwiseBroadcastingResultExtents( + Location loc, Value lhs, Value rhs, OpBuilder& builder, + bool unsafe_as_extent_tensor) { auto lhs_type = lhs.getType().dyn_cast(); auto rhs_type = rhs.getType().dyn_cast(); if (!lhs_type || !rhs_type) { @@ -57,17 +58,22 @@ Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, return nullptr; } - int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); - auto shape_type = shape::ShapeType::get(builder.getContext()); - Value lhs_shape_v = - builder.createOrFold(loc, shape_type, lhs); - Value rhs_shape_v = - builder.createOrFold(loc, shape_type, rhs); - Value result_shape_v = builder.createOrFold( - loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */); - return builder.createOrFold( - loc, RankedTensorType::get({result_rank}, builder.getIndexType()), - result_shape_v); + Value lhs_shape_v = builder.createOrFold(loc, lhs); + Value rhs_shape_v = builder.createOrFold(loc, rhs); + + if (unsafe_as_extent_tensor) { + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + Value result_shape_v = builder.createOrFold( + loc, shape::getExtentTensorType(builder.getContext()), lhs_shape_v, + rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); + } + + return builder.createOrFold( + loc, builder.getType(), lhs_shape_v, rhs_shape_v, + nullptr /* error */); } } // namespace hlo diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc index ea074c4907d..0751d2c626c 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/convert_op_folder.cc @@ -15,11 +15,11 @@ limitations under the License. // This file defines helpers useful when creating or manipulating lhlo/hlo. 
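To make the ComputeBinaryElementwiseBroadcastingResultExtents change above concrete: with the new unsafe_as_extent_tensor flag, the helper emits one of two shape computations. A sketch of both forms, with illustrative operand types and an assumed result rank of 2 (the chlo_infer_shape_type_methods.mlir and chlo_legalize_to_hlo_broadcasts.mlir expectations later in this diff are the authoritative patterns):

func @result_extents_sketch(%lhs: tensor<?xf32>, %rhs: tensor<?x?xf32>) -> (!shape.shape, tensor<2xindex>) {
  %lhs_shape = shape.shape_of %lhs : tensor<?xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<?x?xf32> -> tensor<?xindex>
  // unsafe_as_extent_tensor == false: keep the error-carrying !shape.shape result.
  %safe = shape.broadcast %lhs_shape, %rhs_shape
      : tensor<?xindex>, tensor<?xindex> -> !shape.shape
  // unsafe_as_extent_tensor == true: broadcast as extent tensors and cast to the
  // statically known result rank; only valid once broadcastability has been
  // established, e.g. inside a shape.assuming region.
  %bcast = shape.broadcast %lhs_shape, %rhs_shape
      : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
  %unsafe = tensor_cast %bcast : tensor<?xindex> to tensor<2xindex>
  return %safe, %unsafe : !shape.shape, tensor<2xindex>
}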
-#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" +#include "mlir-hlo/utils/convert_op_folder.h" -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" namespace mlir { namespace hlo { diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector.cc b/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector.cc index 6145391a608..0914460236d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" +#include "mlir-hlo/utils/cycle_detector.h" #include diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector_test.cc b/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector_test.cc index 314bbd699c7..263321c17d1 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector_test.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/cycle_detector_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" +#include "mlir-hlo/utils/cycle_detector.h" #include "tensorflow/compiler/xla/test.h" diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc index 184d113fb9d..df2442cc4b6 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" +#include "mlir-hlo/utils/hlo_utils.h" #include -#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" namespace mlir { namespace hlo { diff --git a/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt new file mode 100644 index 00000000000..36a7eec5a1f --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/CMakeLists.txt @@ -0,0 +1,36 @@ +# +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +configure_lit_site_cfg( + ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in + ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py + MAIN_CONFIG + ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py +) + +set(MLIR_HLO_TEST_DEPENDS + FileCheck count not + mlir-hlo-opt +) + +add_lit_testsuite(check-mlir-hlo-lit "Running the mlir-hlo regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${MLIR_HLO_TEST_DEPENDS} + ) +set_target_properties(check-mlir-hlo-lit PROPERTIES FOLDER "Tests") + +add_lit_testsuites(MLIR_HLO_OPT ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${MLIR_HLO_TEST_DEPENDS}) + +add_dependencies(check-mlir-hlo check-mlir-hlo-lit) diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 87774129ffb..15b1a150fdd 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -191,6 +191,20 @@ func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { return %2 : tensor<2x2xi32> } +// CHECK-LABEL: constant_like_constant +func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> { + // CHECK: mhlo.constant dense<3.200000e+00> + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// CHECK-LABEL: constant_like_constant_dynamic +func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> { + // CHECK: chlo.constant_like + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: dynamic_slice_variable_start func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { // CHECK: "mhlo.dynamic-slice" @@ -365,6 +379,25 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar return %0 : tensor<5x4xf32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_1 +func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + return %2 : tensor +} + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2 +func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> !shape.shape + %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + return %2 : tensor +} + // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor @@ -542,3 +575,25 @@ func @dce_while_without_side_effect(%arg0: tensor) -> tensor { return %arg0 : tensor } + +// CHECK-LABEL: unpack_repack_same_tuple +// CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) +func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token + %2 = 
"mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> + + // CHECK: return [[ARG0]] + return %3 : tuple, !mhlo.token, tensor> +} + +// CHECK-LABEL: unpack_repack_same_tuple_single_element +// CHECK-SAME: ([[ARG0:%.*]]: tuple>) +func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> + + // CHECK: return [[ARG0]] + return %3 : tuple> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir index 65074325563..d226c92858a 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir @@ -5,15 +5,14 @@ // only test reification on an examplar op. // CHECK-SAME: %[[ARG0:.+]]: tensor, // CHECK-SAME: %[[ARG1:.+]]: tensor -func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { +func @broadcast_add(%arg0: tensor, %arg1: tensor) -> !shape.shape { // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] - // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) - // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] - // CHECK: return %[[EXTENTS]] + // CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] : tensor, tensor -> !shape.shape + // CHECK: return %[[BCAST_S]] : !shape.shape %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor - %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> - return %1 : tensor<1xindex> + %1 = "mhlo_test.reify_return_type_shapes"(%0) : (tensor) -> !shape.shape + return %1 : !shape.shape } // ----- diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index 2c0e2d7f170..9670372a864 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -18,8 +18,8 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]] @@ -39,8 +39,8 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) - // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, 
tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> @@ -60,8 +60,8 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] - // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]] + // CHECK: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_S]] : tensor to tensor<2xindex> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor @@ -237,3 +237,199 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } + +// ----- +func @addScalarUnranked(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor<*xf32>) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addScalarUnranked( +// CHECK-SAME: %[[ARG_0:.*]]: tensor, +// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32> +// CHECK-SAME: ) -> tensor<*xf32> { +// First handle the dynamic reshaping of the unranked operand +// to a 1D tensor. +// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32> +// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]] : tensor -> index +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// The assuming region is part of the second stage of lowering +// with ranked broadcasting logic. 
+// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor +// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor +// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]] +// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { +// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape [] +// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]] +// CHECK: %[[SHAPE_TENSOR:.*]] = tensor_cast %[[BROADCASTED_SHAPE]] : tensor to tensor<1xindex> +// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor +// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor +// CHECK: } +// As part of the unranked logic, the result is reshaped back +// to an unranked tensor. +// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_1]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32> +// CHECK: } + +// ----- +func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} +// CHECK-LABEL: func @addUnrankedScalar( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: tensor) -> tensor<*xf32> { +// First handle the dynamic reshaping of the unranked operand +// to a 1D tensor. +// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32> +// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]] : tensor -> index +// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> +// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor +// The assuming region is part of the second stage of lowering +// with ranked broadcasting logic. +// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor +// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor +// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]] +// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor) { +// CHECK: %[[ASTENSOR:.*]] = tensor_cast %[[SHAPE_RESHAPED]] +// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[ASTENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[ASTENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor +// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor +// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor +// CHECK: } +// As part of the unranked logic, the result is reshaped back +// to an unranked tensor. 
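The @addUnrankedUnranked expectations that follow exercise the rank-specialized fallback of this lowering: compute both shapes, take the larger rank, and for each rank up to 6 broadcast each shape against an all-ones shape of that rank, reshape both operands to that rank, and apply the ranked chlo.broadcast_add. A condensed sketch of the shared prologue plus the rank-2 branch; the scalar and equal-shape early-outs are omitted, and the types, the const_shape annotation, and the assert message are illustrative assumptions (the full CHECK sequence below is authoritative):

func @rank_specialization_sketch(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<*xf32> {
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = rank %lhs_shape : tensor<?xindex>
  %rhs_rank = rank %rhs_shape : tensor<?xindex>
  %lhs_greater = cmpi "sgt", %lhs_rank, %rhs_rank : index
  %greatest_rank = select %lhs_greater, %lhs_rank, %rhs_rank : index
  %c2 = constant 2 : index
  %is_rank_2 = cmpi "eq", %greatest_rank, %c2 : index
  %result = scf.if %is_rank_2 -> (tensor<*xf32>) {
    // Pad both shapes to rank 2 with ones, reshape, and reuse the ranked lowering.
    %ones = shape.const_shape [1, 1] : tensor<2xindex>
    %lhs_shape_2 = shape.broadcast %lhs_shape, %ones : tensor<?xindex>, tensor<2xindex> -> tensor<2xindex>
    %rhs_shape_2 = shape.broadcast %rhs_shape, %ones : tensor<?xindex>, tensor<2xindex> -> tensor<2xindex>
    %lhs_2 = "mhlo.dynamic_reshape"(%lhs, %lhs_shape_2) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %rhs_2 = "mhlo.dynamic_reshape"(%rhs, %rhs_shape_2) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %sum_2 = chlo.broadcast_add %lhs_2, %rhs_2 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
    %cast_2 = tensor_cast %sum_2 : tensor<?x?xf32> to tensor<*xf32>
    scf.yield %cast_2 : tensor<*xf32>
  } else {
    // Ranks 3 through 6 repeat the same pattern; the final branch asserts.
    %false = constant false
    assert %false, "unsupported rank"
    scf.yield %lhs : tensor<*xf32>
  }
  return %result : tensor<*xf32>
}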
+// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_0]]) : (tensor, tensor) -> tensor<*xf32> +// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32> +// CHECK: } + +// ----- +func @addUnrankedUnranked( + %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<*xf32>) + -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addUnrankedUnranked( +// CHECK-SAME: %[[LHS:.*]]: tensor<*xf32>, +// CHECK-SAME: %[[RHS:.*]]: tensor<*xf32>) -> tensor<*xf32> { +// CHECK: %[[LHS_SHAPE:.*]] = shape.shape_of %[[LHS]] : tensor<*xf32> -> tensor +// CHECK: %[[RANK_LHS:.*]] = shape.rank %[[LHS_SHAPE]] : tensor -> index +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[LHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_LHS]], %[[C0]] : index +// Handle scalar LHS case +// CHECK: %[[VAL_8:.*]] = scf.if %[[LHS_IS_SCALAR]] -> (tensor<*xf32>) { +// CHECK: %[[SCALAR_LHS:.*]] = "mhlo.reshape"(%[[LHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[VAL_10:.*]] = chlo.broadcast_add %[[SCALAR_LHS]], %[[RHS]] : (tensor, tensor<*xf32>) -> tensor<*xf32> +// CHECK: scf.yield %[[VAL_10]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[RHS_SHAPE:.*]] = shape.shape_of %[[RHS]] : tensor<*xf32> -> tensor +// CHECK: %[[RANK_RHS:.*]] = shape.rank %[[RHS_SHAPE]] : tensor -> index +// CHECK: %[[RHS_IS_SCALAR:.*]] = cmpi "eq", %[[RANK_RHS]], %[[C0]] : index + // Handle scalar RHS case +// CHECK: %[[VAL_14:.*]] = scf.if %[[RHS_IS_SCALAR]] -> (tensor<*xf32>) { +// CHECK: %[[SCALAR_RHS:.*]] = "mhlo.reshape"(%[[RHS]]) : (tensor<*xf32>) -> tensor +// CHECK: %[[VAL_16:.*]] = chlo.broadcast_add %[[LHS]], %[[SCALAR_RHS]] : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: scf.yield %[[VAL_16]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[SHAPES_EQ:.*]] = shape.shape_eq %[[LHS_SHAPE]], %[[RHS_SHAPE]] : tensor, tensor + // Handle scalar RHS case +// CHECK: %[[VAL_18:.*]] = scf.if %[[SHAPES_EQ]] -> (tensor<*xf32>) { +// CHECK: %[[VAL_19:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor<*xf32> +// CHECK: scf.yield %[[VAL_19]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[LHS_RANK:.*]] = rank %[[LHS_SHAPE]] : tensor +// CHECK: %[[RHS_RANK:.*]] = rank %[[RHS_SHAPE]] : tensor +// CHECK: %[[LHS_RANK_GREATER:.*]] = cmpi "sgt", %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK: %[[GREATEST_RANK:.*]] = select %[[LHS_RANK_GREATER]], %[[LHS_RANK]], %[[RHS_RANK]] : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[GREATEST_RANK_IS_2:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C2]] : index +// Handle rank 2 specialization +// CHECK: %[[VAL_26:.*]] = scf.if %[[GREATEST_RANK_IS_2]] -> (tensor<*xf32>) { +// CHECK: %[[CONST_SHAPE_2:.*]] = shape.const_shape [1, 1] +// CHECK: %[[BROADCASTED_LHS_2:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> +// CHECK: %[[BROADCASTED_RHS_2:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_2]] : tensor, tensor<2xindex> -> tensor<2xindex> +// CHECK: %[[RESHAPED_LHS_2:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_2:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_2]]) : (tensor<*xf32>, tensor<2xindex>) -> tensor +// CHECK: %[[RESULT_RANK_2:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_2]], %[[RESHAPED_RHS_2]] : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_2:.*]] = tensor_cast %[[RESULT_RANK_2]] : tensor to tensor<*xf32> +// CHECK: 
scf.yield %[[RESULT_2]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[C3:.*]] = constant 3 : index +// CHECK: %[[GREATEST_RANK_IS_3:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C3]] : index +// Handle rank 3 specialization +// CHECK: %[[VAL_34:.*]] = scf.if %[[GREATEST_RANK_IS_3]] -> (tensor<*xf32>) { +// CHECK: %[[CONST_SHAPE_3:.*]] = shape.const_shape [1, 1, 1] +// CHECK: %[[BROADCASTED_LHS_3:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor<3xindex> +// CHECK: %[[BROADCASTED_RHS_3:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_3]] : tensor, tensor<3xindex> -> tensor<3xindex> +// CHECK: %[[RESHAPED_LHS_3:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_3:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_3]]) : (tensor<*xf32>, tensor<3xindex>) -> tensor +// CHECK: %[[RESULT_RANK_3:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_3]], %[[RESHAPED_RHS_3]] : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_3:.*]] = tensor_cast %[[RESULT_RANK_3]] : tensor to tensor<*xf32> +// CHECK: scf.yield %[[RESULT_3]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[C4:.*]] = constant 4 : index +// CHECK: %[[GREATEST_RANK_IS_4:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C4]] : index +// Handle rank 4 specialization +// CHECK: %[[VAL_42:.*]] = scf.if %[[GREATEST_RANK_IS_4]] -> (tensor<*xf32>) { +// CHECK: %[[CONST_SHAPE_4:.*]] = shape.const_shape [1, 1, 1, 1] +// CHECK: %[[BROADCASTED_LHS_4:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> +// CHECK: %[[BROADCASTED_RHS_4:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_4]] : tensor, tensor<4xindex> -> tensor<4xindex> +// CHECK: %[[RESHAPED_LHS_4:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_4:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_4]]) : (tensor<*xf32>, tensor<4xindex>) -> tensor +// CHECK: %[[RESULT_RANK_4:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_4]], %[[RESHAPED_RHS_4]] : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_4:.*]] = tensor_cast %[[RESULT_RANK_4]] : tensor to tensor<*xf32> +// CHECK: scf.yield %[[RESULT_4]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[C5:.*]] = constant 5 : index +// CHECK: %[[GREATEST_RANK_IS_5:.*]] = cmpi "eq", %[[GREATEST_RANK]], %[[C5]] : index +// Handle rank 5 specialization +// CHECK: %[[VAL_50:.*]] = scf.if %[[GREATEST_RANK_IS_5]] -> (tensor<*xf32>) { +// CHECK: %[[CONST_SHAPE_5:.*]] = shape.const_shape [1, 1, 1, 1, 1] +// CHECK: %[[BROADCASTED_LHS_5:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> +// CHECK: %[[BROADCASTED_RHS_5:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_5]] : tensor, tensor<5xindex> -> tensor<5xindex> +// CHECK: %[[RESHAPED_LHS_5:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_5:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_5]]) : (tensor<*xf32>, tensor<5xindex>) -> tensor +// CHECK: %[[RESULT_RANK_5:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_5]], %[[RESHAPED_RHS_5]] : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_5:.*]] = tensor_cast %[[RESULT_RANK_5]] : tensor to tensor<*xf32> +// CHECK: scf.yield %[[RESULT_5]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %[[C6:.*]] = constant 6 : index +// CHECK: %[[GREATEST_RANK_IS_6:.*]] = 
cmpi "eq", %[[GREATEST_RANK]], %[[C6]] : index +// Handle rank 6 specialization +// CHECK: %[[VAL_58:.*]] = scf.if %[[GREATEST_RANK_IS_6]] -> (tensor<*xf32>) { +// CHECK: %[[CONST_SHAPE_6:.*]] = shape.const_shape [1, 1, 1, 1, 1, 1] +// CHECK: %[[BROADCASTED_LHS_6:.*]] = shape.broadcast %[[LHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> +// CHECK: %[[BROADCASTED_RHS_6:.*]] = shape.broadcast %[[RHS_SHAPE]], %[[CONST_SHAPE_6]] : tensor, tensor<6xindex> -> tensor<6xindex> +// CHECK: %[[RESHAPED_LHS_6:.*]] = "mhlo.dynamic_reshape"(%[[LHS]], %[[BROADCASTED_LHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[RESHAPED_RHS_6:.*]] = "mhlo.dynamic_reshape"(%[[RHS]], %[[BROADCASTED_RHS_6]]) : (tensor<*xf32>, tensor<6xindex>) -> tensor +// CHECK: %[[RESULT_RANK_6:.*]] = chlo.broadcast_add %[[RESHAPED_LHS_6]], %[[RESHAPED_RHS_6]] : (tensor, tensor) -> tensor +// CHECK: %[[RESULT_6:.*]] = tensor_cast %[[RESULT_RANK_6]] : tensor to tensor<*xf32> +// CHECK: scf.yield %[[RESULT_6]] : tensor<*xf32> +// CHECK: } else { +// CHECK: %false = constant false +// CHECK: assert %false +// CHECK: scf.yield %[[LHS]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_64:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_65:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_66:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_67:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_68:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_69:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: scf.yield %[[VAL_70:.*]] : tensor<*xf32> +// CHECK: } +// CHECK: return %[[VAL_71:.*]] : tensor<*xf32> +// CHECK: } diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-gather-to-torch-index-select.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-gather-to-torch-index-select.mlir new file mode 100644 index 00000000000..ca90a80aa6c --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-gather-to-torch-index-select.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s + +// CHECK-LABEL: @gather_to_index_select +func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> { + // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = 0 : i64, + // CHECK-SAME: dim = 0 : i64 + // CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32> + // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]]) + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32> + + // CHECK: return [[RES]] + return %0 : tensor<1x3x4xf32> +} + +// CHECK-LABEL: @scalar_gather_to_index_select +func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor) -> tensor<1x4xf32> { + // CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) { + // CHECK-SAME: batch_dims = 0 : i64, + // CHECK-SAME: dim = 0 : i64 + // CHECK-SAME: } : (tensor<5x4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]]) + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, 
start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor) -> tensor<1x4xf32> + + // CHECK: return [[RES]] + return %0 : tensor<1x4xf32> +} + +// CHECK-LABEL: @gather_no_lowering_subslice +func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> { + // CHECK: "mhlo.gather" + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32> + return %0 : tensor<1x3x3xf32> +} + +// CHECK-LABEL: @gather_no_lowering_multidim +func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> { + // CHECK: "mhlo.gather" + %0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32> + return %0 : tensor<1x3x4xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index aa5d800b82b..018711e33cb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s -// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s // BOTH-LABEL: func @attrs func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { @@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> return %out : tensor<3x5x5x4xf32> } + +// ----- + +// BOTH-LABEL: func @reduce +func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { + // BOTH: %[[OUT:.*]] = alloc() : memref<1xf32> + // BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { + // BOTH: ^bb0(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, + // BOTH-SAME: %[[ARG3:.*]]: memref): + // BOTH: %[[TMP:.*]] = alloc() : memref + // BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) + // BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) + // BOTH: "lmhlo.terminator"() : () -> () + // BOTH: }) {dimensions = dense<1> : tensor<1xi64>} + // BOTH-SAME: : (memref<1x8xf32>, memref, memref<1xf32>) -> () + %0 = "mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (tensor<1x8xf32>, tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git 
a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 320ce069ac0..46725e0bd09 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @float_add @@ -557,3 +557,18 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { } // CHECK: linalg.generic // CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + +// ----- + +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @iota +func @iota() -> tensor<7x10xf32> { + %result = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> (tensor<7x10xf32>) + return %result : tensor<7x10xf32> +} +// CHECK: linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index): +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 +// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir index 6d7992cb868..3271595900d 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir @@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, dealloc %0 : memref<2x2xf32> "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-LABEL: func @reduce +func @reduce(%arg0: memref<1x8xf32>, %arg1: memref, %arg2: memref<1xf32>) { + %0 = alloc() : memref<1xf32> + "lmhlo.reduce"(%arg0, %arg1, %0) ( { + // CHECK: ^bb0(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, + // CHECK-SAME: %[[ARG2:.*]]: memref) + ^bb0(%arg3: memref, %arg4: memref, %arg5: memref): + %1 = alloc() : memref + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + "lmhlo.add"(%arg3, %arg4, %1) + : (memref, memref, memref) -> () + // CHECK-NOT; lmhlo.copy + "lmhlo.copy"(%1, %arg5) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (memref<1x8xf32>, memref, memref<1xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> () + return +} diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index dd88e5c80bf..768d8da22bd 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir index a25a508b2d3..45c383bd1d6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s 
--test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s // CHECK-LABEL: func @static_memref_cast func @static_memref_cast(%buf : memref<10x1x5xf32>) { @@ -11,11 +11,11 @@ func @static_memref_cast(%buf : memref<10x1x5xf32>) { // CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]] // CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]] -// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr to !llvm.ptr // CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]] // CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]] -// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr to !llvm.ptr // CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]] // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 @@ -50,11 +50,11 @@ func @dynamic_memref_cast(%buf : memref) { // CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]] // CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr to !llvm.ptr // CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]] // CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]] -// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr to !llvm.ptr // CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]] // CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]] diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir index 1530f59317d..47ef99bcac0 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FILECHECK_OPTS="" FileCheck %s func @reduce(%arg: memref<100x10x5xf32>, %init: memref, diff --git a/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py b/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py new file mode 100644 index 00000000000..f81d47a76cd --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lit.cfg.py @@ -0,0 +1,82 @@ +"""Lit configuration to drive test in this repo.""" +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- Python -*- +# pylint: disable=undefined-variable + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import lit.formats +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst +import lit.util + +# Configuration file for the 'lit' test runner. + +# name: The name of this test suite. +config.name = 'MLIR_HLO_OPT' + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.mlir', '.mlir.py'] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.mlir_hlo_obj_root, 'test') + +config.substitutions.append(('%PATH%', config.environment['PATH'])) +config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) + +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) + +llvm_config.use_default_substitutions() + +# excludes: A list of directories to exclude from the testsuite. The 'Inputs' +# subdirectories contain auxiliary inputs for various tests in their parent +# directories. +config.excludes = [ + 'Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt' +] + +# test_source_root: The root path where tests are located. +config.test_source_root = os.path.dirname(__file__) + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.path.join(config.mlir_hlo_obj_root, 'test') +config.mlir_hlo_tools_dir = os.path.join(config.mlir_hlo_obj_root, 'tools') + +# Tweak the PATH to include the tools dir. +llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) + +tool_dirs = [ + os.path.join(config.mlir_hlo_tools_dir, 'mlir-hlo-opt'), + config.llvm_tools_dir, +] +tools = [ + 'mlir-hlo-opt', + 'mlir-cpu-runner', + ToolSubst( + '%mlir_runner_utils_dir', + config.mlir_runner_utils_dir, + unresolved='ignore'), +] + +llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in b/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in new file mode 100644 index 00000000000..1555d314df0 --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/lit.site.cfg.py.in @@ -0,0 +1,63 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
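As a usage sketch for the new CMake/lit wiring: once configured, tests are plain .mlir files under tests/, picked up by the suffixes set in lit.cfg.py above and driven through the mlir-hlo-opt tool substitution, with check-mlir-hlo-lit as the lit target per the CMakeLists.txt earlier in this diff. A trivial, hypothetical test file (the function name and checked op are only illustrative) would look like:

// RUN: mlir-hlo-opt %s | FileCheck %s

// CHECK-LABEL: func @roundtrip
func @roundtrip(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK: mhlo.add
  %0 = mhlo.add %arg0, %arg0 : tensor<4xf32>
  return %0 : tensor<4xf32>
}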
+ +@LIT_SITE_CFG_IN_HEADER@ + +import sys + +config.host_triple = "@LLVM_HOST_TRIPLE@" +config.target_triple = "@TARGET_TRIPLE@" +config.llvm_src_root = "@LLVM_SOURCE_DIR@" +config.llvm_obj_root = "@LLVM_BINARY_DIR@" +config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" +config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" +config.llvm_shlib_dir = "@SHLIBDIR@" +config.llvm_shlib_ext = "@SHLIBEXT@" +config.llvm_exe_ext = "@EXEEXT@" +config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" +config.python_executable = "@PYTHON_EXECUTABLE@" +config.gold_executable = "@GOLD_EXECUTABLE@" +config.ld64_executable = "@LD64_EXECUTABLE@" +config.enable_shared = @ENABLE_SHARED@ +config.enable_assertions = @ENABLE_ASSERTIONS@ +config.targets_to_build = "@TARGETS_TO_BUILD@" +config.native_target = "@LLVM_NATIVE_ARCH@" +config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') +config.host_os = "@HOST_OS@" +config.host_cc = "@HOST_CC@" +config.host_cxx = "@HOST_CXX@" +# Note: ldflags can contain double-quoted paths, so must use single quotes here. +config.host_ldflags = '@HOST_LDFLAGS@' +config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" +config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' +config.host_arch = "@HOST_ARCH@" +config.mlir_hlo_src_root = "@CMAKE_SOURCE_DIR@" +config.mlir_hlo_obj_root = "@CMAKE_BINARY_DIR@" +config.mlir_runner_utils_dir = os.path.join(config.llvm_obj_root, "lib") + +# Support substitution of the tools_dir with user parameters. This is +# used when we can't determine the tool dir at configuration time. +try: + config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params + config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params +except KeyError: + e = sys.exc_info()[1] + key, = e.args + lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) + + +import lit.llvm +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. +lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/tests/lit.cfg.py") diff --git a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir index 8d84e7140f3..a7bd21257a6 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir @@ -182,11 +182,10 @@ func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { // CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1 // CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]] // CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]]) - %1 = "mhlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) - %2 = "mhlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %1 = "mhlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xf32>) // CHECK: return [[VAL3]] - return %2 : tensor<2xf32> + return %1 : tensor<2xf32> } // CHECK-LABEL: @exp diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir index 80474156f29..56a7cf7294c 100644 --- a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir @@ -5,10 +5,9 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // Flatten operand shape. 
- %shape = shape.shape_of %a : tensor<*xf32> - %num_elements = shape.num_elements %shape - %num_elements_as_index = shape.size_to_index %num_elements - %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> + %shape = shape.shape_of %a : tensor<*xf32> -> tensor + %num_elements = shape.num_elements %shape : tensor -> index + %flat_shape = tensor_from_elements(%num_elements) : tensor<1xindex> %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape) : (tensor<*xf32>, tensor<1xindex>) -> tensor @@ -16,8 +15,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { %flat_b = "mhlo.sqrt"(%flat_a) : (tensor) -> tensor // Restore original shape. - %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor - %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) + %b = "mhlo.dynamic_reshape"(%flat_b, %shape) : (tensor, tensor) -> tensor<*xf32> return %b : tensor<*xf32> @@ -29,14 +27,12 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: @sqrt // CHECK-SAME: (%[[A:.*]]: tensor<*xf32>) func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> { - // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor) -> tensor - // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32> %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> return %b : tensor<*xf32> @@ -73,15 +69,13 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]] // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] - // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]] + // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]]) // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] - // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] - // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> + // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor - // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor, tensor) -> tensor<*xf32> + // 
CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor, tensor) -> tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32> return %result : tensor<*xf32> diff --git a/tensorflow/compiler/mlir/hlo/tests/ops.mlir b/tensorflow/compiler/mlir/hlo/tests/ops.mlir index b46827b88a5..a8f16c403ae 100644 --- a/tensorflow/compiler/mlir/hlo/tests/ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/ops.mlir @@ -116,6 +116,30 @@ func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> // ----- +// CHECK-LABEL: func @dynamic_broadcast_in_dim_unknown_dim +func @dynamic_broadcast_in_dim_unknown_dim(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_ok_dim +func @dynamic_broadcast_in_dim_ok_dim(%arg0: tensor<1xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<1xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + +func @dynamic_broadcast_in_dim_shape_mismatch(%arg0: tensor<32xf32>, %shape: tensor<3xi64>) -> tensor<7x8x9xf32> { + // expected-error@+1 {{size of operand dimension 0 (32) is not compatible with size of result dimension 2 (9)}} + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<32xf32>, tensor<3xi64>) -> tensor<7x8x9xf32> + return %0 : tensor<7x8x9xf32> +} + +// ----- + func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> @@ -456,7 +480,7 @@ func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4 // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor<5xf32>): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -468,7 +492,7 @@ func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: t // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -480,7 +504,7 @@ func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: te // expected-error@+1 {{computation must return single output, but got: 0}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor + %1 = mhlo.constant dense<2.0> : tensor "mhlo.return"() : () -> () }) {dimensions = dense<[0, 1]> : 
tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -492,7 +516,7 @@ func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4 // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> + %1 = mhlo.constant dense<2.0> : tensor<5xf32> "mhlo.return"(%1) : (tensor<5xf32>) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -504,7 +528,7 @@ func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5 // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} %0 = "mhlo.map"(%arg0, %arg1) ( { ^bb0(%arg2: tensor, %arg3: tensor): - %1 = mhlo.constant {value = dense<2> : tensor} : tensor + %1 = mhlo.constant dense<2> : tensor "mhlo.return"(%1) : (tensor) -> () }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> return %0 : tensor<4x5xf32> @@ -730,6 +754,14 @@ func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tenso // ----- +func @dynamic_update_slice_mismatched_start(%input: tensor<11x3x4xi32>, %update: tensor<1x3x4xi32>, %start1: tensor, %start2: tensor, %start3: tensor) -> tensor<11x3x4xi32> { + // expected-error@+1 {{start indices must have same element type (encountered mismatch: 'i32' vs 'i64')}} + %0 = "mhlo.dynamic-update-slice"(%input, %update, %start1, %start2, %start3) : (tensor<11x3x4xi32>, tensor<1x3x4xi32>, tensor, tensor, tensor) -> tensor<11x3x4xi32> + return %0 : tensor<11x3x4xi32> +} + +// ----- + // CHECK-LABEL: func @transpose func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> @@ -847,6 +879,13 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple // ----- +func @tuple_token(%arg0: tensor, %arg1: !mhlo.token) -> tuple, !mhlo.token> { + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, !mhlo.token) -> tuple, !mhlo.token> + return %0 : tuple, !mhlo.token> +} + +// ----- + func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> @@ -939,7 +978,23 @@ func @constants() -> () { func @constant_invalid() -> () { // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} - %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) + %0 = "mhlo.constant"() {value = dense<0> : tensor} : () -> (tensor<3xi32>) + return +} + +// ----- + +func @constant_invalid() -> () { + // expected-error@+1 {{op result #0 must be statically shaped tensor}} + %0 = "mhlo.constant"() {value = dense<1> : tensor} : () -> tensor + return +} + +// ----- + +func @constant_invalid() -> () { + // expected-error@+1 {{elements literal type must have static shape}} + %0 = "mhlo.constant"() {value = dense<1> : tensor} : () -> tensor return } diff --git a/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir new file mode 100644 index 00000000000..c20de0b2a9f --- /dev/null +++ 
b/tensorflow/compiler/mlir/hlo/tests/optimize-hlo.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
+
+// CHECK-LABEL: @gather_is_slice_no_rank
+func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor) -> tensor<1x2xi32>
+
+  // CHECK: return [[RESHAPE]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice
+func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
+  // CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor
+  // CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
+  // CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
+
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}
+
+// CHECK-LABEL: @gather_is_slice_multiple_start_indices
+func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
+  // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
+  // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
+  // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
+  // CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
+  // CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
+  // CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
+  %res = "mhlo.gather"(%arg0, %arg1) {
+    dimension_numbers = {
+      collapsed_slice_dims = dense<0> : tensor<1xi64>,
+      index_vector_dim = 0 : i64,
+      offset_dims = dense<[0, 1]> : tensor<2xi64>,
+      start_index_map = dense<0> : tensor<1xi64>
+    },
+    slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
+  } : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
+
+  // CHECK: return [[RES]]
+  return %res : tensor<1x2xi32>
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir b/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir
index f8b6b629c9e..9e18ad8a2d8 100644
--- a/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/sink-constants-to-control-flow.mlir
@@ -58,3 +58,17 @@ func @sink_const_to_conditional(%arg0: tensor) -> tensor {
   %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor
   return %9 : tensor
 }
+
+func @sink_const_to_sort(%arg0: tensor<16xf32>) {
+  %c0 = constant dense<1.0> : tensor
+  // CHECK: "mhlo.sort"
+  %0 = "mhlo.sort"(%arg0) ( {
+  ^bb0(%arg1: tensor, %arg2: tensor):
+    // CHECK: constant dense<1.000000e+00>
+    %1 = "mhlo.divide"(%arg1, %c0) : (tensor, tensor) -> tensor
+    %2 = "mhlo.divide"(%arg2, %c0) : (tensor, tensor) -> tensor
+    %3 = "mhlo.compare"(%1, %2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
+    "mhlo.return"(%3) : (tensor) -> ()
+  }) {is_stable = true} : (tensor<16xf32>) -> tensor<16xi32>
+  return
+}
diff --git a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir
index c1930721218..f903dbb7080 100644
--- a/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir
+++ b/tensorflow/compiler/mlir/hlo/tests/unfuse_batch_norm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
+// RUN: mlir-hlo-opt -split-input-file -mhlo-test-unfuse-batch-norm -verify-diagnostics %s | FILECHECK_OPTS="" FileCheck --enable-var-scope %s
 // CHECK-LABEL: @batchNormInference_2D_inner_features
 // CHECK-SAME: %[[X:[^:[:space:]]+]]
diff --git a/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt
new file mode 100644
index 00000000000..0f3d1c85795
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tools/CMakeLists.txt
@@ -0,0 +1,16 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+add_subdirectory(mlir-hlo-opt)
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
new file mode 100644
index 00000000000..754469a3c84
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt
@@ -0,0 +1,32 @@
+#
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+set(LIBS
+  ${dialect_libs}
+  ${conversion_libs}
+  MLIROptLib
+
+  MhloRegisterDialects
+  AllMhloPasses
+  )
+add_llvm_executable(mlir-hlo-opt mlir-hlo-opt.cpp
+  DEPENDS
+  MLIRLmhloPassIncGen
+  MLIRMhloPassIncGen
+)
+llvm_update_compile_flags(mlir-hlo-opt)
+target_link_libraries(mlir-hlo-opt PRIVATE ${LIBS})
diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
new file mode 100644
index 00000000000..70fc21d6959
--- /dev/null
+++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp
@@ -0,0 +1,121 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir-hlo/Dialect/mhlo/IR/register.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/InitAllPasses.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/MlirOptMain.h"
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+                                                llvm::cl::desc(""),
+                                                llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<std::string> outputFilename(
+    "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+    llvm::cl::init("-"));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> splitInputFile(
+    "split-input-file",
+    llvm::cl::desc("Split the input file into pieces and process each "
+                   "chunk independently"),
+    llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> verifyDiagnostics(
+    "verify-diagnostics",
+    llvm::cl::desc("Check that emitted diagnostics match "
+                   "expected-* lines on the corresponding line"),
+    llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> verifyPasses(
+    "verify-each",
+    llvm::cl::desc("Run the verifier after each transformation pass"),
+    llvm::cl::init(true));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> allowUnregisteredDialects(
+    "allow-unregistered-dialect",
+    llvm::cl::desc("Allow operation with no registered dialects"),
+    llvm::cl::init(false));
+
+// NOLINTNEXTLINE
+static llvm::cl::opt<bool> showDialects(
+    "show-dialects", llvm::cl::desc("Print the list of registered dialects"),
+    llvm::cl::init(false));
+
+int main(int argc, char **argv) {
+  mlir::registerAllDialects();
+  mlir::registerAllPasses();
+
+  mlir::mhlo::registerAllDialects();
+  mlir::mhlo::registerAllMhloPasses();
+  mlir::lmhlo::registerAllLmhloPasses();
+
+  llvm::InitLLVM y(argc, argv);
+
+  // Register any pass manager command line options.
+  mlir::registerPassManagerCLOptions();
+  mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
+
+  // Parse pass names in main to ensure static initialization completed.
+  llvm::cl::ParseCommandLineOptions(argc, argv,
+                                    "MLIR modular optimizer driver\n");
+
+  if (showDialects) {
+    mlir::MLIRContext context;
+    llvm::outs() << "Registered Dialects:\n";
+    for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
+      llvm::outs() << dialect->getNamespace() << "\n";
+    }
+    return 0;
+  }
+
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = mlir::openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  auto output = mlir::openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    exit(1);
+  }
+
+  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
+                         splitInputFile, verifyDiagnostics, verifyPasses,
+                         allowUnregisteredDialects))) {
+    return 1;
+  }
+  // Keep the output file if the invocation of MlirOptMain was successful.
+  output->keep();
+  return 0;
+}
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 8d0c204f434..555c11779f5 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -25,7 +25,6 @@ package_group(
 filegroup(
     name = "tensorflow_lite_ops_td_files",
     srcs = [
-        "experimental/tfl_hardware_interfaces.td",
         "ir/tfl_op_interfaces.td",
         "ir/tfl_ops.td",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
@@ -221,18 +220,14 @@ cc_library(
     ],
     deps = [
         ":tensorflow_lite_ops_inc_gen",
-        ":validators",
-        "//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
         "//tensorflow/lite/schema:schema_fbs",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:DerivedAttributeOpInterface",
         "@llvm-project//mlir:Dialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LoopLikeInterface",
-        "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:QuantOps",
         "@llvm-project//mlir:SideEffects",
         "@llvm-project//mlir:StandardOps",
@@ -273,7 +268,9 @@ cc_library(
     deps = [
         ":tensorflow_lite",
         "//tensorflow/compiler/mlir/tensorflow",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
         "//tensorflow/core:framework",
+        "@flatbuffers",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:StandardOps",
@@ -338,6 +335,7 @@ cc_library(
         "transforms/optimize_functional_ops.cc",
         "transforms/prepare_composite_functions_tf.cc",
         "transforms/prepare_tf.cc",
+        "transforms/raise_custom_ops.cc",
         "transforms/runtime_verify.cc",
         "transforms/split_merged_operands.cc",
         "transforms/trim_functions_tf.cc",
@@ -349,26 +347,29 @@ cc_library(
         "transforms/passes.h",
     ],
     deps = [
-        ":common",
         ":lstm_utils",
         ":stateful_ops_utils",
         ":tensorflow_lite",
         ":tftext_utils",
         ":validators",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
+        "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_tensor",
         "//tensorflow/compiler/mlir/tensorflow:mangling_util",
+        "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
         "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
+        "//tensorflow/compiler/mlir/tensorflow:tf_legalize_hlo",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:tensor_list", - "//tensorflow/core/platform:logging", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", @@ -399,7 +400,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", @@ -433,7 +433,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", @@ -454,7 +453,7 @@ cc_library( deps = [ ":tensorflow_lite", "//tensorflow/lite/tools/optimize/sparsity:format_converter", - "@com_google_absl//absl/base", + "//third_party/eigen3", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -480,7 +479,6 @@ gentbl( td_srcs = [ "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", - "experimental/tfl_hardware_interfaces.td", "ir/tfl_op_interfaces.td", ], ) @@ -609,8 +607,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", - "//tensorflow/compiler/mlir/tensorflow:mangling_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", @@ -620,7 +616,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite:string_util", - "//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib", + "//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/tools/versioning", @@ -651,7 +647,6 @@ cc_library( ":flatbuffer_tflite_operator_lib", ":tensorflow_lite", ":tensorflow_lite_dialect_registration", - "//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", @@ -724,7 +719,6 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirTranslateMain", @@ -858,10 +852,8 @@ cc_library( "//tensorflow/core:core_cpu_base", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD 
b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD index 04d5d3db918..373c95f6bf5 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/BUILD @@ -8,9 +8,6 @@ package( cc_library( name = "cost_estimators", textual_hdrs = [ - "estimator.h", - "cpu_estimators.h", - "gpu_estimators.h", "hardware.h", "arithmetic_count_util.h", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h index 2ca49e4e1e5..782714f5955 100644 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h +++ b/tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h @@ -15,13 +15,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_ +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project + // For add/mul/div/sub and other broadcastable ops. class ArithmeticCountUtilHelper { public: static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op, int64_t* count) { auto output = op->getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = + output.getType().dyn_cast_or_null(); if (!output_type || !output_type.hasStaticShape()) return false; *count = output_type.getNumElements(); @@ -31,7 +35,8 @@ class ArithmeticCountUtilHelper { static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) { int64_t total_count = 0; for (auto input : op->getOperands()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = + input.getType().dyn_cast_or_null(); if (!input_type || !input_type.hasStaticShape()) { return false; } @@ -43,14 +48,16 @@ class ArithmeticCountUtilHelper { // For conv2d/depthwise_conv/fully_connected ops. // This algorithm actually comes from TOCO tooling_util.cc - static bool GetArithmeticCountForConvAndFullyconnectedOp(Operation* op, + static bool GetArithmeticCountForConvAndFullyconnectedOp(mlir::Operation* op, int64_t* count) { auto weight = op->getOperand(1); - auto weight_type = weight.getType().dyn_cast_or_null(); + auto weight_type = + weight.getType().dyn_cast_or_null(); if (weight_type == nullptr || !weight_type.hasStaticShape()) return false; auto output = op->getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = + output.getType().dyn_cast_or_null(); if (output_type == nullptr || !output_type.hasStaticShape()) return false; int64_t cols = 1; @@ -63,7 +70,8 @@ class ArithmeticCountUtilHelper { auto bias = op->getOperand(2); if (bias) { - auto bias_type = bias.getType().dyn_cast_or_null(); + auto bias_type = + bias.getType().dyn_cast_or_null(); if (bias_type && bias_type.hasStaticShape()) { *count += bias_type.getNumElements(); } diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h deleted file mode 100644 index b47c08c7cb4..00000000000 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ - -// CPU -constexpr float kCPUArithmeticUnitCost = 1.0; - -// This basically assumes pure load/store. This is just fake data. -constexpr float kCPUCopyUnitCost = 0.5; -constexpr float kCPUDefaultCost = 3.0f; - -// Default values. -constexpr float kCPUDefaultFixedValuedCost = 10000.0; - -// tfl.add -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, - &count)) - return kCPUArithmeticUnitCost * count; - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.concatenation -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) - return kCPUCopyUnitCost * count; - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kCPUArithmeticUnitCost; - } - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.depthwise_conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kCPUArithmeticUnitCost; - } - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.fully_connected -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kCPUArithmeticUnitCost; - } - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.mul -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, - &count)) - return kCPUArithmeticUnitCost * count; - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.pack -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t 
count; - if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) - return kCPUCopyUnitCost * count; - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.reshape -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) - return kCPUCopyUnitCost * count; - return kCPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h b/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h deleted file mode 100644 index c4a509945fa..00000000000 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ - -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/Operation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h" -#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" - -template -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { - llvm::errs() << "No defined support for op: " - << op->getName().getStringRef().str(); - return false; - } -}; - -// All ops on CPU are supported. -// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param. -template -class TFLiteCostEstimator { - public: - // TODO(karimnosseir): Update and use table based method and lookup - // cost from a loadable table ? - static double GetCost(mlir::Operation* op) { return 0.0; } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h b/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h deleted file mode 100644 index 45e8707ef44..00000000000 --- a/tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h +++ /dev/null @@ -1,543 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ - -// GPU -constexpr float kGPUArithmeticUnitCost = 0.2; - -// The copy can be non-consectutive copy. This is just fake data. -constexpr float kGPUCopyUnitCost = 0.2; -constexpr float kGPUDefaultCost = 1.0f; - -// Default values. -constexpr float kGPUDefaultFixedValuedCost = 10000.0; - -// tfl.abs -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.add -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, - &count)) - return kGPUArithmeticUnitCost * count; - return kGPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.average_pool_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.concatenation -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) - return kGPUCopyUnitCost * count; - return kGPUDefaultFixedValuedCost; - } - - // TODO(renjieliu): We probably need to check for dynamic weights. - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kGPUArithmeticUnitCost; - } - return kGPUDefaultFixedValuedCost; - } - - // TODO(renjieliu): We probably need to check for dynamic weights. 
- static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.cos -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.depthwise_conv_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kGPUArithmeticUnitCost; - } - return kGPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.div -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.exp -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.fully_connected -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t arithmetic_count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp( - op, &arithmetic_count)) { - return arithmetic_count * kGPUArithmeticUnitCost; - } - return kGPUDefaultFixedValuedCost; - } - - // TODO(renjieliu): we need to check for dynamic weights. 
- static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.hard_swish -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.log -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.logistic -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.max_pool_2d -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.mirror_pad -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.maximum -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.custom -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.mean -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - // TODO(renjieiu): check for constraints. 
- static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.minimum -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.mul -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op, - &count)) - return kGPUArithmeticUnitCost * count; - return kGPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.pad -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.pow -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.prelu -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.relu -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.relu6 -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.reshape -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - int64_t count; - if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count)) - return kGPUCopyUnitCost * count; - return kGPUDefaultFixedValuedCost; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.rsqrt -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.sin -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.slice -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.softmax -template <> -class TFLiteCostEstimator { - public: - static double 
GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.space_to_depth -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.sqrt -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.square -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.squared_difference -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.strided_slice -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.tanh -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.transpose -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -// tfl.transpose_conv -template <> -class TFLiteCostEstimator { - public: - static double GetCost(mlir::Operation* op) { - llvm::errs() << "No defined cost function for op: " - << op->getName().getStringRef().str(); - return 0.0; - } - - static bool IsSupported(mlir::Operation* op) { return true; } -}; - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_ - diff --git a/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td b/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td deleted file mode 100644 index 5c3ec6c206c..00000000000 --- a/tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// WARNING: This Interface is experimental, DO NOT USE. - -// This is the Target Hardware operation interfacea definition file -// for TensorFlow Lite. - -#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES -#define TFL_TARGET_HARDWARE_OP_INTERFACES - -def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> { - let description = [{ - Interface for ops to run on CPU. - }]; - - let methods = [ - InterfaceMethod< - [{Returns the cost of running this op on CPU.}], - // TODO(karimnosseir): Change to return Cost object instead. - "double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{ - // TODO(karimnosseir): Consider changing to another way that doesn't - // rely on template param name. - return TFL::TFLiteCostEstimator::GetCost(op_to_check); - }] - >, - InterfaceMethod< - [{Returns whether this op can be run on CPU.}], - "bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{ - // TODO(karimnosseir): Consider changing to another way that doesn't - // rely on template param name. - return TFL::TFLiteCostEstimator::IsSupported(op_to_check); - }] - >, - ]; -} - -def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> { - let description = [{ - Interface for ops to run on GPU. - }]; - - let methods = [ - InterfaceMethod< - [{Returns the cost of running this op on GPU.}], - // TODO(karimnosseir): Change to return Cost object instead. - "double", "GetOpCost", (ins "Operation*":$op_to_check), [{ - // TODO(karimnosseir): Consider changing to another way that doesn't - // rely on template param name. - return TFL::TFLiteCostEstimator::GetCost(op_to_check); - }] - >, - InterfaceMethod< - [{Returns whether this op can be run on GPU.}], - "bool", "IsSupported", (ins "Operation*":$op_to_check), [{ - // TODO(karimnosseir): Consider changing to another way that doesn't - // rely on template param name. - return TFL::TFLiteCostEstimator::IsSupported(op_to_check); - }] - >, - ]; -} - -#endif // TFL_TARGET_HARDWARE_OP_INTERFACES diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index fb20e842a75..89fae87cb25 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -149,6 +149,9 @@ static StatusOr GetTFLiteType(Type type, if (ftype && ftype.isF32()) { return tflite::TensorType_COMPLEX64; } + if (ftype && ftype.isF64()) { + return tflite::TensorType_COMPLEX128; + } return Status(error::INVALID_ARGUMENT, "Unsupported type"); } case mlir::StandardTypes::Integer: { @@ -1193,22 +1196,35 @@ Optional> Translator::BuildSubGraph( if (IsConst(&inst)) continue; // Fetch operand and result tensor indices. - std::vector operands; - operands.reserve(inst.getNumOperands()); - for (auto operand : inst.getOperands()) { - if (operand.getType().isa()) - operands.push_back(kTfLiteOptionalTensor); - else - operands.push_back(tensor_index_map.lookup(operand)); - } std::vector results; results.reserve(inst.getNumOperands()); for (auto result : inst.getResults()) { results.push_back(tensor_index_map.lookup(result)); } + Operation* real_inst = &inst; + // CustomTfOp is just a wrapper around a TF op, we export the custom Op + // not the wrapper, so we fetch the op from the region. 
+ if (auto custom_op = dyn_cast(inst)) { + // If we have custom op with a region, then use the first op in the + // region, if it exists, otherwise just use params for custom op. + if (!custom_op.body().empty()) { + real_inst = &custom_op.body().front().front(); + } else { + module_.emitError( + "Invalid CustomTfOp: Custom TF Op have empty region."); + } + } + std::vector operands; + operands.reserve(real_inst->getNumOperands()); + for (auto operand : real_inst->getOperands()) { + if (operand.getType().isa()) + operands.push_back(kTfLiteOptionalTensor); + else + operands.push_back(tensor_index_map.lookup(operand)); + } if (auto tfl_operator = - BuildOperator(&inst, operands, results, intermediates)) + BuildOperator(real_inst, operands, results, intermediates)) operators.push_back(*tfl_operator); else failed_once = true; @@ -1402,7 +1418,7 @@ BufferOffset Translator::BuildSparsityParameters( } else { auto segments = dim_metadata.segments(); std::vector vector_segments(segments.size(), 0); - for (int j = 0; j < segments.size(); j++) { + for (int j = 0, end = segments.size(); j < end; j++) { vector_segments[j] = segments[j].dyn_cast().getInt(); } tflite::SparseIndexVector segments_type; @@ -1434,7 +1450,7 @@ BufferOffset Translator::BuildSparsityParameters( auto indices = dim_metadata.indices(); std::vector vector_indices(indices.size(), 0); int max_of_indices = 0; - for (int j = 0; j < indices.size(); j++) { + for (int j = 0, end = indices.size(); j < end; j++) { vector_indices[j] = indices[j].dyn_cast().getInt(); if (vector_indices[j] > max_of_indices) { max_of_indices = vector_indices[j]; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index fa85b4e50fd..3c8bf26aa14 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -229,7 +229,7 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, llvm::SmallVector min_maxs; min_maxs.reserve(mins.size() * 2); - for (int i = 0; i < mins.size(); ++i) { + for (int i = 0, end = mins.size(); i < end; ++i) { llvm::APFloat min(mins[i]); llvm::APFloat max(maxs[i]); min_maxs.push_back(min); @@ -281,7 +281,7 @@ std::vector ReadAsLittleEndian(ArrayRef bytes) { int bytes_len = bytes.size(); assert(bytes_len % read_size == 0); - size_t elem_count = bytes_len / read_size; + int elem_count = bytes_len / read_size; ret.reserve(elem_count); const char* data_ptr = reinterpret_cast(bytes.data()); @@ -318,7 +318,7 @@ StatusOr ConvertFloatBuffer( switch (elem_type.getWidth()) { case 16: { assert(bytes_len % 2 == 0); - size_t elem_count = bytes_len / 2; + int elem_count = bytes_len / 2; std::vector values; values.reserve(elem_count); @@ -337,7 +337,7 @@ StatusOr ConvertFloatBuffer( } case 32: { assert(bytes_len % 4 == 0); - size_t elem_count = bytes_len / 4; + int elem_count = bytes_len / 4; std::vector values; values.reserve(elem_count); @@ -353,7 +353,7 @@ StatusOr ConvertFloatBuffer( } case 64: { assert(bytes_len % 8 == 0); - size_t elem_count = bytes_len / 8; + int elem_count = bytes_len / 8; std::vector values; values.reserve(elem_count); @@ -829,7 +829,7 @@ StatusOr ConvertSubgraph( // Add state variables to inputs. 
absl::flat_hash_set input_index_set(func_inputs.begin(), func_inputs.end()); - for (int i = 0; i < subgraph.tensors.size(); i++) { + for (int i = 0, end = subgraph.tensors.size(); i < end; i++) { auto& tensor = *subgraph.tensors.at(i); if (tensor.is_variable && !input_index_set.contains(i)) { func_inputs.emplace_back(i); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index becc2f7ab85..e14178d6f6d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -19,7 +19,6 @@ limitations under the License. #define TFL_OP_INTERFACES include "mlir/IR/OpBase.td" -include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td" //===----------------------------------------------------------------------===// // TFL op interface for stateful operands. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 427b9c692a7..b5fcd5e82e2 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -269,7 +269,7 @@ struct TensorFlowLiteOpFolderDialectInterface }; TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context) - : Dialect(/*name=*/"tfl", context) { + : Dialect(/*name=*/"tfl", context, TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" @@ -773,8 +773,8 @@ static LogicalResult Verify(CustomOp op) { op.custom_option().cast(); if (!opaque_attr.getType().hasStaticShape()) return op.emitOpError("custom_option should have a static shape."); - if (opaque_attr.getValue().size() != - opaque_attr.getType().cast().getDimSize(0)) + const int attribute_size = opaque_attr.getValue().size(); + if (attribute_size != opaque_attr.getType().cast().getDimSize(0)) return op.emitOpError( "custom_option should have the same length of content with shape."); return success(); @@ -955,7 +955,7 @@ static LogicalResult Verify(ScatterNdOp op) { // Checks whether the last `(shape_type.getDimSize(0) - outermost_dim)` // dimensions of `updates` and `shape` are equal. 
for (auto shape_it : llvm::enumerate(shape_value)) { - auto i = shape_it.index(); + int64_t i = shape_it.index(); auto value = shape_it.value().getSExtValue(); if (i >= outermost_dim) { auto corresponding_dim = i - outermost_dim + outer_dims; @@ -1192,7 +1192,8 @@ struct RemoveRedundantUnpackPack : public RewritePattern { return failure(); const int total_pack_inputs = pack_op.getNumOperands(); - if (total_pack_inputs != input_unpack_op.getNumResults()) return failure(); + const int num_results = input_unpack_op.getNumResults(); + if (total_pack_inputs != num_results) return failure(); for (auto input_output : llvm::zip(pack_op.getOperands(), input_unpack_op.getResults())) { Value pack_input = std::get<0>(input_output); @@ -1261,8 +1262,7 @@ static LogicalResult Verify(SliceOp op) { } if (begin && size && input_type.hasStaticShape()) { - const int input_rank = begin.getNumElements(); - for (uint64_t i = 0; i < input_rank; i++) { + for (uint64_t i = 0, end = begin.getNumElements(); i < end; i++) { int begin_i = begin.getValue({i}).cast().getValue().getSExtValue(); int size_i = diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index c7a1504c3b7..caed0bb3ad9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/lite/schema/schema_generated.h" namespace mlir { @@ -48,14 +48,9 @@ class TensorFlowLiteDialect : public Dialect { Location loc) override; }; -#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc" #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc" -// Include all specializes estimators below this line -#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h" -#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h" -#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h" } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 4a56d893b19..6dc9fda656f 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -410,10 +410,7 @@ def TFL_ComparisonBinaryBuilder : OpBuilder< class TFL_Op traits = []> : Op, - // All TFL ops are supported on CPU. - DeclareOpInterfaceMethods - ])> { + [DeclareOpInterfaceMethods])> { // FlatBuffer generation specific information. 
// ------------------------------------------- // When generating the FlatBuffer output some operations have @@ -435,8 +432,7 @@ class TFL_Op traits = []> : class TFL_ConvOp : TFL_Op, - AffineQuantizedOpInterface, AffineOpCoefficient, - TFL_GpuTargetOp, TFL_SparseOp]> { + AffineQuantizedOpInterface, AffineOpCoefficient, TFL_SparseOp]> { let summary = opSummary # " operator"; let description = [{ @@ -473,8 +469,7 @@ def TFL_AbsOp : TFL_Op<"abs", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Absolute value operator"; let description = [{ @@ -495,8 +490,7 @@ def TFL_AddOp : TFL_Op<"add", [ CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast($_op))">>, ResultsBroadcastableShape, NoSideEffect, - Commutative, - TFL_GpuTargetOp]> { + Commutative]> { let summary = "Addition operator"; let description = [{ @@ -573,7 +567,6 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ TFL_TCresVTEtIsSameAsOp<0, 2>>, AccumulatorUniformScale<3, 1, 2>, AffineQuantizedOpInterface, AffineOpCoefficient<0, 2>, - TFL_GpuTargetOp, TFL_SparseOp]> { let summary = "Transpose convolution operator"; @@ -612,8 +605,7 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ def TFL_AveragePool2DOp: TFL_Op<"average_pool_2d", [NoSideEffect, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Average_pool_2d operator"; let description = [{ @@ -713,8 +705,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation", NoSideEffect, PredOpTrait<"values and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - SameOperandsAndResultsScale, - TFL_GpuTargetOp + SameOperandsAndResultsScale ]> { let summary = "Concatenation operator"; @@ -861,8 +852,7 @@ def TFL_CosOp: TFL_Op<"cos", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Cosine operator"; let description = [{ @@ -916,8 +906,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [ NoSideEffect, AccumulatorUniformScale<2, 0, 1>, AffineQuantizedOpInterface, AffineOpCoefficient<-1, 1>, - TFL_SparseOp, - TFL_GpuTargetOp]> { + TFL_SparseOp]> { let summary = "Fully connected op"; let arguments = (ins @@ -954,7 +943,10 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [ NoSideEffect, TFL_OperandHasAtleastRank<0, 2>, TFL_OperandHasAtleastRank<1, 2>, - SameOperandsAndResultElementType]> { + PredOpTrait<"x and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 0>>, + PredOpTrait<"y and output must have same element type", + TFL_TCresVTEtIsSameAsOp<0, 1>>]> { let summary = "Batch Matrix Multiply Operator"; @@ -1070,8 +1062,7 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ ResultsBroadcastableShape, BinaryOpSameElementTypeConstraint, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, - NoSideEffect, - NoQuantizableResult]> { + NoSideEffect]> { let summary = "Less_equal operator"; let description = [{ @@ -1132,8 +1123,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, ResultsBroadcastableShape, - NoSideEffect, - NoQuantizableResult]> { + NoSideEffect]> { let summary = "Greater_equal operator"; let description = [{ @@ -1360,8 +1350,7 @@ def TFL_DivOp : TFL_Op<"div", [ TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>, 
ResultsBroadcastableShape, NoSideEffect, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Division operator"; let description = [{ @@ -1427,7 +1416,6 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", def TFL_EqualOp: TFL_Op<"equal", [ Commutative, - NoQuantizableResult, ResultsBroadcastableShape, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> { @@ -1449,8 +1437,7 @@ def TFL_EqualOp: TFL_Op<"equal", [ } def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect, - SameOperandsAndResultType, - TFL_GpuTargetOp]> { + SameOperandsAndResultType]> { let summary = "Natural exponentiation operator"; let description = [{ @@ -1634,8 +1621,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [ ResultsBroadcastableShape, BinaryOpSameElementTypeConstraint, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, - NoSideEffect, - NoQuantizableResult]> { + NoSideEffect]> { let summary = "Greater operator"; let description = [{ @@ -1659,8 +1645,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [ NoSideEffect, SameOperandsAndResultShape, PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_GpuTargetOp]> { + TFL_TCresVTEtIsSameAsOp<0, 0>>]> { let summary = "Hardswish activation function."; let description = [{ Computes hard-swish activation function @@ -1676,12 +1661,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [ } def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, - FixedOutputRangeInterface, - // central_value = min_value / 2 + (max_value - 1) / 2 + 1 - // zero_point = central_value - // scale = 1. / (central_value - min_value) - FixedResultScale>, - FixedResultScale>]> { + FixedOutputRangeInterface]> { let summary = "L2 Normalize Operator"; let description = [{ @@ -1703,29 +1683,12 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [NoSideEffect, // FixedOutputRangeInterface: quant::UniformQuantizedType GetFixedOutputRange( bool is_signed, int bit_width) { - auto result_type = output().getType().cast(); - if (!result_type.getElementType().isa()) return {}; - Builder builder(result_type.getContext()); - - // Only support 8-bits - if (bit_width != 8) return {}; - IntegerType storage_type = builder.getIntegerType(bit_width); - - double scale = 1.0 / 128; - int64_t zero_point, storage_min, storage_max; - if (is_signed) { - zero_point = 0; - storage_min = -128; - storage_max = 127; - } else { - zero_point = 128; - storage_min = 0; - storage_max = 255; - } - - return quant::UniformQuantizedType::getChecked( - is_signed, storage_type, result_type.getElementType(), scale, - zero_point, storage_min, storage_max, builder.getUnknownLoc()); + auto result_type = output().getType(); + // central_value = min_value / 2 + (max_value - 1) / 2 + 1 + // zero_point = central_value + // scale = 1. / (central_value - min_value) + return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + /*scale=*/1.0 / 128, /*zero_point=*/0); } }]; } @@ -1757,8 +1720,7 @@ def TFL_LessOp : TFL_Op<"less", [ ResultsBroadcastableShape, BinaryOpSameElementTypeConstraint, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, - NoSideEffect, - NoQuantizableResult]> { + NoSideEffect]> { let summary = "Less operator"; let description = [{ @@ -1834,12 +1796,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultShape, - // zero_point = 0 - // scale = 1. 
/ (max_value + 1) - FixedResultScale>, - FixedResultScale>, - FixedOutputRangeInterface, - TFL_GpuTargetOp]> { + FixedOutputRangeInterface]> { let summary = "Logistic operator"; let description = [{ @@ -1854,29 +1811,11 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ // FixedOutputRangeInterface: quant::UniformQuantizedType GetFixedOutputRange( bool is_signed, int bit_width) { - auto result_type = y().getType().cast(); - if (!result_type.getElementType().isa()) return {}; - Builder builder(result_type.getContext()); - - // Only support 8-bits - if (bit_width != 8) return {}; - IntegerType storage_type = builder.getIntegerType(bit_width); - - double scale = 1.0 / 256; - int64_t zero_point, storage_min, storage_max; - if (is_signed) { - zero_point = -128; - storage_min = -128; - storage_max = 127; - } else { - zero_point = 0; - storage_min = 0; - storage_max = 255; - } - - return quant::UniformQuantizedType::getChecked( - is_signed, storage_type, result_type.getElementType(), scale, - zero_point, storage_min, storage_max, builder.getUnknownLoc()); + auto result_type = y().getType(); + // zero_point = 0 + // scale = 1. / (max_value + 1) + return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + /*scale=*/1.0 / 256, /*zero_point=*/-128); } }]; } @@ -1885,8 +1824,7 @@ def TFL_LogOp: TFL_Op<"log", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Natural logarithm operator"; let description = [{ @@ -1905,10 +1843,7 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ SameOperandsAndResultShape, PredOpTrait<"x and y must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - // zero_point = max_value - // scale = -log_softmax_output_min / (max_value + 1) - FixedResultScale>, - FixedResultScale>]> { + FixedOutputRangeInterface]> { let summary = "Log softmax operator"; let description = [{ @@ -1922,6 +1857,18 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ let results = (outs TFL_TensorOf<[F32, QUI8, QI8, TFL_Quint8]>:$output); let hasOptions = 1; + + let extraClassDeclaration = [{ + // FixedOutputRangeInterface: + quant::UniformQuantizedType GetFixedOutputRange( + bool is_signed, int bit_width) { + auto result_type = output().getType(); + // zero_point = max_value + // scale = -log_softmax_output_min / (max_value + 1) + return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + /*scale=*/16.0 / 256, /*zero_point=*/127); + } + }]; } // TODO(ashwinm): Revisit the granularity of the PredOpTraits. We could @@ -1943,8 +1890,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, MaxPoolOperandAndResultConstraints, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Max Pool 2D op"; let description = [{ @@ -1976,8 +1922,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ NoSideEffect, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>, Commutative, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Max operator"; let description = [{ Element-wise max operation. 
@@ -2000,8 +1945,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [ def TFL_MeanOp : TFL_Op<"mean", [ PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - NoSideEffect, - TFL_GpuTargetOp]> { + NoSideEffect]> { let summary = "Mean operator"; let description = [{ @@ -2079,8 +2023,7 @@ def TFL_SliceOp : TFL_Op<"slice", [ SameOperandsAndResultsScale, TFL_OperandHasRankAtMost<0, 4>, TFL_OperandHasRankAtMost<1, 1>, - TFL_OperandHasRankAtMost<2, 1>, - TFL_GpuTargetOp]> { + TFL_OperandHasRankAtMost<2, 1>]> { let summary = "Return a slice from 'input'."; let description = [{ @@ -2211,8 +2154,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [ NoSideEffect, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>, Commutative, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Min operator"; let description = [{ Element-wise min operation. @@ -2238,8 +2180,7 @@ def TFL_MulOp : TFL_Op<"mul", [ Commutative, BinaryOpSameElementTypeConstraint, TFL_RuntimePredOpTrait<"Operands do not have valid shapes", - CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast($_op))">>, - TFL_GpuTargetOp]> { + CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast($_op))">>]> { let summary = "Multiplication operator"; let description = [{ @@ -2345,8 +2286,7 @@ def TFL_PadOp : TFL_Op<"pad", [ TFL_OperandRankEquals1DimOfOperand<0, 1>, PredOpTrait<"the first dim size of the padding argument must be at most 4", Or<[TFL_OperandIsUnrankedPred<1>, - TFL_OperandDimIsAtMost<1, 0, 4>]>>, - TFL_GpuTargetOp]> { + TFL_OperandDimIsAtMost<1, 0, 4>]>>]> { let summary = "Padding operator"; let description = [{ @@ -2439,8 +2379,7 @@ def TFL_PowOp : TFL_Op<"pow", [ ResultsBroadcastableShape, NoSideEffect, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Power operator"; let description = [{ @@ -2463,7 +2402,6 @@ def TFL_PowOp : TFL_Op<"pow", [ def TFL_PReluOp : TFL_Op<"prelu", [ NoSideEffect, ResultsBroadcastableShape, - TFL_GpuTargetOp, TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>, BinaryOpSameElementTypeConstraint, PredOpTrait<"input and output must have the same element type", @@ -2505,8 +2443,7 @@ def TFL_ReluOp: TFL_Op<"relu", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Relu operator"; let description = [{ @@ -2535,8 +2472,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, NoSideEffect, SameOperandsAndResultShape, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Relu6 operator"; let description = [{ @@ -2590,7 +2526,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ } def TFL_ReshapeOp: TFL_Op<"reshape", [ - NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> { + NoSideEffect, SameOperandsAndResultsScale]> { let summary = "Reshape operator"; let description = [{ @@ -2645,8 +2581,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType, SameOperandsAndResultShape, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Reciprocal of square root operator"; let description = [{ @@ -2741,6 +2676,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [ // are unranked. Therefore, we skip adding shape constraints here. 
def TFL_SelectOp : TFL_Op<"select", [ NoSideEffect, + SameOperandsAndResultsScale, PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>, PredOpTrait<"operands and result have same element type", TFL_TCresVTEtIsSameAsOp<0, 1>>]> { @@ -2812,8 +2748,7 @@ def TFL_SinOp: TFL_Op<"sin", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Sine operator"; let description = [{ @@ -2833,11 +2768,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, TFL_OperandHasRankRange<0, 1, 4>, SameOperandsAndResultShape, - // zero_point = 0 - // scale = 1. / (max_value + 1) - FixedResultScale>, - FixedResultScale>, - TFL_GpuTargetOp]> { + FixedOutputRangeInterface]> { let summary = "Softmax operator"; let description = [{ @@ -2854,14 +2785,25 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ let results = (outs TFL_TensorOf<[F32, QI8, QUI8, TFL_Quint8]>:$output); let hasOptions = 1; + + let extraClassDeclaration = [{ + // FixedOutputRangeInterface: + quant::UniformQuantizedType GetFixedOutputRange( + bool is_signed, int bit_width) { + auto result_type = output().getType(); + // zero_point = 0 + // scale = 1. / (max_value + 1) + return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + /*scale=*/1.0 / 256, /*zero_point=*/-128); + } + }]; } def TFL_SqrtOp: TFL_Op<"sqrt", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Square root operator"; let description = [{ @@ -2879,8 +2821,7 @@ def TFL_SquareOp: TFL_Op<"square", [ NoSideEffect, SameOperandsAndResultShape, SameOperandsAndResultType, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Square operator"; let description = [{ @@ -2933,8 +2874,7 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [ SameOperandsAndResultElementType, ResultsBroadcastableShape, NoSideEffect, - NoQuantizableResult, - TFL_GpuTargetOp]> { + NoQuantizableResult]> { let summary = "Squared difference operator"; let description = [{ @@ -2959,12 +2899,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [ SameOperandsAndResultShape, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - // central_value = min_value / 2 + (max_value - 1) / 2 + 1 - // zero_point = central_value - // scale = 1. / (central_value - min_value) - FixedResultScale>, - FixedResultScale>, - TFL_GpuTargetOp]> { + FixedOutputRangeInterface]> { let summary = "Hyperbolic tangent operator"; let description = [{ @@ -2985,6 +2920,19 @@ def TFL_TanhOp: TFL_Op<"tanh", [ state.addTypes(input.getType()); }]> ]; + + let extraClassDeclaration = [{ + // FixedOutputRangeInterface: + quant::UniformQuantizedType GetFixedOutputRange( + bool is_signed, int bit_width) { + auto result_type = output().getType(); + // central_value = min_value / 2 + (max_value - 1) / 2 + 1 + // zero_point = central_value + // scale = 1. 
/ (central_value - min_value) + return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + /*scale=*/1.0 / 128, /*zero_point=*/0); + } + }]; } def TFL_TileOp: TFL_Op<"tile", [ @@ -3052,8 +3000,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", [ TFL_OperandHasRank<1, 1>, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - SameOperandsAndResultsScale, - TFL_GpuTargetOp]> { + SameOperandsAndResultsScale]> { let summary = "Transpose operator"; let description = [{ @@ -3187,8 +3134,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [ SameOperandsAndResultsScale, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, - TFL_OperandHasRankAtMost<0, 4>, - TFL_GpuTargetOp + TFL_OperandHasRankAtMost<0, 4> ]> { let summary = "SpaceToDepth operator"; @@ -3400,8 +3346,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ TFL_OperandHasRankAtMost<0, 5>, TFL_OperandHasRank<1, 1>, TFL_OperandHasRank<2, 1>, - TFL_OperandHasRank<3, 1>, - TFL_GpuTargetOp + TFL_OperandHasRank<3, 1> ]> { let summary = "StridedSlice Op"; @@ -3451,7 +3396,7 @@ def TFL_CastOp : TFL_Op<"cast", [ } def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [ - NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> { + NoSideEffect, TFL_OperandHasRank<1, 2>]> { let summary = "MirrorPad Operator. Pads a tensor with mirrored values."; let description = [{ @@ -4354,7 +4299,8 @@ def TFL_WhileOp : Op { +def TFL_CustomOp : Op { let summary = "Custom op"; let description = [{ @@ -4377,4 +4323,29 @@ def TFL_CustomOp : Op { let verifier = [{ return Verify(*this); }]; } +def TFL_CustomTfOp : Op]> { + let summary = "Wrapper Op for TF custom ops."; + + let description = [{ + A wrapper op around any Custom TF op. These include ops defined using + custom_opdefs or linked ops which are not defined in the TF dialect. + This Op just wraps the custom op inside a region. + Note #1, this Op will not include TF Lite custom ops defined using CustomOp. + Note #2, this op is just an internal representation inside the converter and + is not exposed/exported when the model is exported to Flatbuffer.
+ }]; + + let arguments = (ins + Variadic>:$input + ); + let results = (outs Variadic:$output); + + let regions = (region SizedRegion<1>:$body); +} + #endif // TFL_OPS diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index ddd36fbd74c..529c9ee9238 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -75,7 +75,8 @@ Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags, } auto input_names = input_attr.cast().getValue(); input_names.split(function_input_names, ","); - if (function_input_names.size() != model_flags.input_arrays().size()) { + const int function_input_names_size = function_input_names.size(); + if (function_input_names_size != model_flags.input_arrays().size()) { return errors::InvalidArgument( "input array size mismatch: got ", function_input_names.size(), ", expected: ", model_flags.input_arrays().size()); @@ -99,7 +100,8 @@ Status HandleInputOutputArraysWithModule(const toco::ModelFlags& model_flags, } auto output_names = output_attr.cast().getValue(); output_names.split(function_output_names, ","); - if (function_output_names.size() != model_flags.output_arrays().size()) { + const int function_output_names_size = function_output_names.size(); + if (function_output_names_size != model_flags.output_arrays().size()) { return errors::InvalidArgument( "output array size mismatch: got ", function_output_names.size(), ", expected: ", model_flags.output_arrays().size()); @@ -151,10 +153,13 @@ Status ConvertSavedModelToTFLiteFlatBuffer( return errors::Unimplemented("Only support a single exported name."); } + tensorflow::GraphImportConfig specs; + specs.upgrade_legacy = true; + TF_ASSIGN_OR_RETURN(auto module, ImportSavedModel(model_flags.saved_model_dir(), model_flags.saved_model_version(), tags, - exported_names, &context)); + exported_names, specs, &context)); if (!model_flags.input_arrays().empty() || !model_flags.output_arrays().empty()) { diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 4725eb1ac5f..a4e58123e05 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) { return DT_BOOL; case toco::IODataType::COMPLEX64: return DT_COMPLEX64; + case toco::IODataType::COMPLEX128: + return DT_COMPLEX128; default: return DT_INVALID; } diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index de83a37b82e..aec0d8da34f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -81,7 +81,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 0c9ccf1a979..9e0ad990657 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -794,16 +794,18 @@ 
bool QuantizationDriver::PropagateParams() { } // TODO(fengliuai): make the bit width configurable. - auto spec = GetQuantSpec(op); - auto key = std::make_pair(8, is_signed_); - auto &restricted_outputs = spec->restricted_output_params[key]; - for (int i = 0, e = restricted_outputs.size(); i != e; ++i) { - // The restrict can be nullptr if the result has been quantized. - if (auto params = restricted_outputs[i]) { - changed |= SetResultParams(op, i, params); + if (auto restricted = llvm::dyn_cast(op)) { + // TODO(fengliuai): different result can have different fixed range. + auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8); + for (auto i = 0; i < op->getNumResults(); ++i) { + // The range is null if the result has been quantized. + if (params) { + changed |= SetResultParams(op, i, params); + } } } + auto spec = GetQuantSpec(op); for (auto &it : spec->biases_params) { auto params = GetBiasParams(op, it.first, it.second.first, it.second.second); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 8f6b63b3ee6..9991d103449 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -449,7 +449,7 @@ static bool PreferResultScale(Operation* op) { // only considers the ops with restricted output params. static bool IsStatsRedundant(Operation* op, OpQuantSpecGetter op_quant_spec_getter) { - return !op_quant_spec_getter(op)->restricted_output_params.empty(); + return llvm::isa(op); } bool RemoveRedundantStatsOps(mlir::FuncOp func, @@ -469,7 +469,7 @@ bool RemoveRedundantStatsOps(mlir::FuncOp func, // Step 1: forward pass: propagate any value scales which are not produces // by `SameOperandsAndResultsScale`. Additionally, remove the value scales - // which are produced by the `restricted_output_params`. + // which are produced by the ops with the `FixedOutputRangeInterface`. // Note that we don't propagate across the multiple-operands // `SameOperandsAndResultsScale` ops like `concatenation`. 
func.walk( @@ -594,5 +594,27 @@ LogicalResult VerifySameScales(Operation* op) { } return success(); } + +quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point, + int64_t storage_min, + int64_t storage_max) { + auto result_type = tensor_type.cast(); + if (!result_type.getElementType().isa()) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits + if (bit_width != 8) return {}; + IntegerType storage_type = builder.getIntegerType(bit_width); + if (!is_signed) { + zero_point += 128; + storage_min += 128; + storage_max += 128; + } + return quant::UniformQuantizedType::getChecked( + is_signed, storage_type, result_type.getElementType(), scale, zero_point, + storage_min, storage_max, builder.getUnknownLoc()); +} } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 4ced43014f5..07e5ba4e879 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -395,8 +395,6 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern { llvm::SmallVector new_output_types; for (auto result : def->getResults()) { - result.getUsers().begin()->dump(); - op.dump(); if (result.hasOneUse() && *result.getUsers().begin() == op) { new_output_types.push_back(op.qtype()); } else { @@ -502,6 +500,13 @@ void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, bool RemoveRedundantStatsOps(mlir::FuncOp func, OpQuantSpecGetter op_quant_spec_getter); +// Given quantization parameters for int8, compute the quantization parameters +// for uint8 if it is required, and wrap the result in a UniformQuantizedType.
+quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point, + int64_t storage_min = -128, + int64_t storage_max = 127); } // namespace quant } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index b37fdb9aa7b..ff7c47fb621 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -canonicalize | FileCheck %s +// RUN: tf-opt %s -canonicalize | FILECHECK_OPTS="" FileCheck %s // CHECK-LABEL: @add_float func @add_float() -> (tensor, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt index 345468e609e..481be9d4deb 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/custom_opdef.pbtxt @@ -36,11 +36,11 @@ versions { producer: 27 } -# CHECK-LABEL: func @main -# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32> -# CHECK-SAME: control_outputs = "" -# CHECK-SAME: inputs = "input0,input1" -# CHECK-SAME: outputs = "output" -# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> -# CHECK-NEXT: return %[[OP]] : tensor<*xi32> -# CHECK-NEXT: } +# CHECK-LABEL: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32> +# CHECK: attributes {tf.entry_function = {control_outputs = "", inputs = "input0,input1", outputs = "output"}} { +# CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%arg0, %arg1) ( { +# CHECK-NEXT: %[[OUTPUTS:.*]] = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> +# CHECK-NEXT: "tfl.yield"(%[[OUTPUTS]]) : (tensor<*xi32>) -> () +# CHECK-NEXT: }) : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32> +# CHECK-NEXT: return %[[CUSTOM]] : tensor<*xi32> +# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir index 50fe804f86c..a622c43c2f2 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/constants.mlir @@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex> { return %0 : tensor<4xcomplex> } +func @complex128() -> tensor<4xcomplex> { + // CHECK-LABEL: @complex128 + // CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex> + %0 = "tfl.pseudo_const"() { value = opaque<"tf", 
"0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex> } : () -> tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + // TODO(b/138847107) this should work but doesn't // func @f16() -> tensor<4xf16> { // %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16> diff --git a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir index 1a3ed0509c4..f6f32e7a069 100644 --- a/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir +++ b/tensorflow/compiler/mlir/lite/tests/fuse-tftext.mlir @@ -1,3196 +1,3438 @@ -// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s -module { +// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s - func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor, tensor) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> - %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> - %2 = "tf.Const"() {value = dense : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %4 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> - %5 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %6 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %7 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %8 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - %9 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %10 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %11 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> - %12 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> - %13 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %14 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %15 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %16 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %17 = "tf.If"(%2, %2, %13, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3210, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3200} : (tensor, tensor, tensor, tensor) -> tensor - %18 = "tf.Identity"(%17) {device = ""} : (tensor) -> tensor - %19 = "tf.StringLength"(%arg0) {device = "", unit = "BYTE"} : (tensor<1x!tf.string>) -> tensor<1xi32> - %20 = "tf.ExpandDims"(%19, %7) {device = ""} : (tensor<1xi32>, tensor) -> tensor<1x1xi32> - %21 
= "tf.Cast"(%20) {Truncate = false, device = ""} : (tensor<1x1xi32>) -> tensor<1x1xi64> - %22 = "tf.Reshape"(%21, %12) {device = ""} : (tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> - %23 = "tf.Reshape"(%arg0, %5) {device = ""} : (tensor<1x!tf.string>, tensor<1xi32>) -> tensor<1x!tf.string> - %24:3 = "tf.UnicodeDecodeWithOffsets"(%23) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor<1x!tf.string>) -> (tensor<2xi64>, tensor, tensor) - %25 = "tf.StridedSlice"(%24#0, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %26 = "tf.AddV2"(%25, %13) {device = ""} : (tensor<1xi64>, tensor) -> tensor<1xi64> - %27 = "tf.StridedSlice"(%24#0, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %28 = "tf.Minimum"(%26, %27) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> - %29:2 = "tf.RaggedRange"(%28, %27, %13) {T = i64, Tsplits = i64, device = ""} : (tensor<1xi64>, tensor<1xi64>, tensor) -> (tensor<2xi64>, tensor) - %30 = "tf.StridedSlice"(%29#0, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %31 = "tf.AddV2"(%30, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %32 = "tf.ConcatV2"(%29#0, %31, %14) {device = ""} : (tensor<2xi64>, tensor<1xi64>, tensor) -> tensor<3xi64> - %33 = "tf.GatherV2"(%24#2, %29#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %34 = "tf.ConcatV2"(%33, %22, %14) {device = ""} : (tensor, tensor<1xi64>, tensor) -> tensor - %35:2 = "tf.RaggedGather"(%32, %34, %0) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor<3xi64>, tensor, tensor<2xi64>) -> (tensor, tensor) - %36:5 = "tf.WhitespaceTokenizeWithOffsets"(%24#1, %24#0) {Tsplits = i64, device = ""} : (tensor, tensor<2xi64>) -> (tensor, tensor, tensor, tensor, tensor) - %37 = "tf.StridedSlice"(%36#1, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %38 = "tf.Equal"(%37, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %39 = "tf.All"(%38, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %40 = "tf.If"(%39, %39, %37, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3960} : (tensor, tensor, tensor, tensor) -> tensor - %41 = "tf.Identity"(%40) {device = ""} : (tensor) -> tensor - %42 = "tf.StridedSlice"(%36#1, %16, %15, %16) {begin_mask = 0 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %43 = "tf.StridedSlice"(%36#1, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %44 = "tf.Sub"(%42, %43) {device = ""} : (tensor, tensor) -> tensor - %45 = "tf.LessEqual"(%10, %44) {device = ""} : (tensor, tensor) -> tensor - %46 = "tf.All"(%45, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %47 = "tf.If"(%46, %46, %44) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4330, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4320} : (tensor, tensor, tensor) -> tensor - %48 = "tf.Identity"(%47) {device = ""} : (tensor) -> tensor - %49 = "tf.Identity"(%36#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %50 = "tf.StridedSlice"(%49, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %51 = "tf.Shape"(%36#0) {device = ""} : (tensor) -> tensor<1xi64> - %52 = "tf.StridedSlice"(%51, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %53 = "tf.Equal"(%50, %52) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %54 = "tf.All"(%53, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %55 = "tf.If"(%54, %54, %50, %52) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4670, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4660} : (tensor, tensor, tensor, tensor) -> tensor - %56 = "tf.Identity"(%55) {device = ""} : (tensor) -> tensor - %57 = "tf.Identity"(%49) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %58 = "tf.Shape"(%57) {device = ""} : (tensor) -> tensor<1xi64> - %59 = "tf.StridedSlice"(%58, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %60 = "tf.Sub"(%59, %13) {device = ""} : (tensor, tensor) -> tensor - %61 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %62 = "tf.Equal"(%61, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) 
-> tensor - %63 = "tf.All"(%62, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %64 = "tf.If"(%63, %63, %61, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5040, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5030} : (tensor, tensor, tensor, tensor) -> tensor - %65 = "tf.Identity"(%64) {device = ""} : (tensor) -> tensor - %66 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %67 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %68 = "tf.Sub"(%66, %67) {device = ""} : (tensor, tensor) -> tensor - %69 = "tf.LessEqual"(%10, %68) {device = ""} : (tensor, tensor) -> tensor - %70 = "tf.All"(%69, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %71 = "tf.If"(%70, %70, %68) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5400, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5390} : (tensor, tensor, tensor) -> tensor - %72 = "tf.Identity"(%71) {device = ""} : (tensor) -> tensor - %73 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %74 = "tf.StridedSlice"(%73, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %75 = "tf.Equal"(%74, %60) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %76 = "tf.All"(%75, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %77 = "tf.If"(%76, %76, %74, %60) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5750} : (tensor, tensor, tensor, tensor) -> tensor - %78 = "tf.Identity"(%77) {device = ""} : (tensor) -> tensor - %79 = "tf.Identity"(%73) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %80 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %81 = "tf.Equal"(%80, 
%10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %82 = "tf.All"(%81, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %83 = "tf.If"(%82, %82, %80, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6110, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6100} : (tensor, tensor, tensor, tensor) -> tensor - %84 = "tf.Identity"(%83) {device = ""} : (tensor) -> tensor - %85 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %86 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %87 = "tf.Sub"(%85, %86) {device = ""} : (tensor, tensor) -> tensor - %88 = "tf.LessEqual"(%10, %87) {device = ""} : (tensor, tensor) -> tensor - %89 = "tf.All"(%88, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %90 = "tf.If"(%89, %89, %87) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6460} : (tensor, tensor, tensor) -> tensor - %91 = "tf.Identity"(%90) {device = ""} : (tensor) -> tensor - %92 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %93 = "tf.StridedSlice"(%92, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %94 = "tf.Shape"(%36#2) {device = ""} : (tensor) -> tensor<1xi64> - %95 = "tf.StridedSlice"(%94, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %96 = "tf.Equal"(%93, %95) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %97 = "tf.All"(%96, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %98 = "tf.If"(%97, %97, %93, %95) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6810, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6800} : (tensor, tensor, tensor, tensor) -> tensor - %99 = "tf.Identity"(%98) {device = ""} : (tensor) -> tensor - %100 = 
"tf.Identity"(%92) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %101 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi64> - %102 = "tf.StridedSlice"(%101, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %103 = "tf.Sub"(%102, %13) {device = ""} : (tensor, tensor) -> tensor - %104 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %105 = "tf.LogicalOr"(%104, %2) {device = ""} : (tensor, tensor) -> tensor - %106 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %107 = "tf.LogicalOr"(%105, %106) {device = ""} : (tensor, tensor) -> tensor - %108 = "tf.StridedSlice"(%100, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %109 = "tf.StridedSlice"(%100, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %110 = "tf.Sub"(%108, %109) {device = ""} : (tensor, tensor) -> tensor - %111 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi64> - %112 = "tf.StridedSlice"(%111, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %113 = "tf.Sub"(%112, %13) {device = ""} : (tensor, tensor) -> tensor - %114 = "tf.Equal"(%113, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %115 = "tf.ExpandDims"(%100, %7) {device = ""} : (tensor, tensor) -> tensor - %116 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi32> - %117 = "tf.StridedSlice"(%116, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %118 = "tf.StridedSlice"(%116, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %119 = "tf.StridedSlice"(%116, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %120 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %121 = "tf.Equal"(%120, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %122 = "tf.All"(%121, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %123 = "tf.If"(%122, %122, %120, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = 
@WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7180, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7170} : (tensor, tensor, tensor, tensor) -> tensor - %124 = "tf.Identity"(%123) {device = ""} : (tensor) -> tensor - %125 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %126 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %127 = "tf.Sub"(%125, %126) {device = ""} : (tensor, tensor) -> tensor - %128 = "tf.LessEqual"(%10, %127) {device = ""} : (tensor, tensor) -> tensor - %129 = "tf.All"(%128, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %130 = "tf.If"(%129, %129, %127) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7530} : (tensor, tensor, tensor) -> tensor - %131 = "tf.Identity"(%130) {device = ""} : (tensor) -> tensor - %132 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %133 = "tf.StridedSlice"(%132, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %134 = "tf.Shape"(%36#3) {device = ""} : (tensor) -> tensor<1xi64> - %135 = "tf.StridedSlice"(%134, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %136 = "tf.Equal"(%133, %135) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %137 = "tf.All"(%136, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %138 = "tf.If"(%137, %137, %133, %135) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7880, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7870} : (tensor, tensor, tensor, tensor) -> tensor - %139 = "tf.Identity"(%138) {device = ""} : (tensor) -> tensor - %140 = "tf.Identity"(%132) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %141 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi64> - %142 = "tf.StridedSlice"(%141, %15, %16, %16) {begin_mask = 0 : i64, device = 
"", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %143 = "tf.Sub"(%142, %13) {device = ""} : (tensor, tensor) -> tensor - %144 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %145 = "tf.LogicalOr"(%144, %2) {device = ""} : (tensor, tensor) -> tensor - %146 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %147 = "tf.LogicalOr"(%145, %146) {device = ""} : (tensor, tensor) -> tensor - %148 = "tf.StridedSlice"(%140, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %149 = "tf.StridedSlice"(%140, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %150 = "tf.Sub"(%148, %149) {device = ""} : (tensor, tensor) -> tensor - %151 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi64> - %152 = "tf.StridedSlice"(%151, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %153 = "tf.Sub"(%152, %13) {device = ""} : (tensor, tensor) -> tensor - %154 = "tf.Equal"(%153, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %155 = "tf.ExpandDims"(%140, %7) {device = ""} : (tensor, tensor) -> tensor - %156 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi32> - %157 = "tf.StridedSlice"(%156, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %158 = "tf.StridedSlice"(%156, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %159 = "tf.StridedSlice"(%156, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %160 = "tf.StridedSlice"(%140, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %161 = "tf.Range"(%10, %160, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %162 = "tf.StridedSlice"(%140, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %163 = "tf.StridedSlice"(%140, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %164 = "tf.Sub"(%162, %163) {device = ""} : (tensor, tensor) -> tensor - %165 = "tf.If"(%107, %107, %13, %103) 
{_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8680, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8670} : (tensor, tensor, tensor, tensor) -> tensor - %166 = "tf.Identity"(%165) {device = ""} : (tensor) -> tensor - %167 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %168 = "tf.Select"(%167, %13, %103) {device = ""} : (tensor, tensor, tensor) -> tensor - %169 = "tf.Equal"(%168, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %170 = "tf.LogicalOr"(%169, %2) {device = ""} : (tensor, tensor) -> tensor - %171 = "tf.Equal"(%168, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %172 = "tf.LogicalOr"(%170, %171) {device = ""} : (tensor, tensor) -> tensor - %173 = "tf.Select"(%114, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %174 = "tf.Pack"(%173, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %175 = "tf.StridedSlice"(%174, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %176 = "tf.Cast"(%175) {Truncate = false, device = ""} : (tensor) -> tensor - %177 = "tf.Reshape"(%176, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %178 = "tf.Pack"(%7, %177) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %179 = "tf.Tile"(%115, %178) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %180 = "tf.Mul"(%177, %118) {device = ""} : (tensor, tensor) -> tensor - %181 = "tf.Pack"(%180) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %182 = "tf.ConcatV2"(%117, %181, %119, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %183 = "tf.Reshape"(%179, %182) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %184 = "tf.Shape"(%183) {device = ""} : (tensor) -> tensor<1xi64> - %185 = "tf.StridedSlice"(%184, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %186 = "tf.Pack"(%175) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %187 = "tf.StridedSlice"(%183, %186, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %188 = "tf.Sub"(%185, %175) {device = ""} : (tensor, tensor) -> tensor - %189 = "tf.Pack"(%188) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %190 = "tf.StridedSlice"(%183, %11, %189, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %191:2 = "tf.RaggedRange"(%190, %187, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %192 = "tf.Select"(%2, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %193 = "tf.Pack"(%192, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %194 = "tf.StridedSlice"(%193, %16, %6, %16) {begin_mask = 0 : i64, 
device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %195 = "tf.Cast"(%194) {Truncate = false, device = ""} : (tensor) -> tensor - %196 = "tf.Reshape"(%195, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %197 = "tf.Pack"(%7, %196) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %198 = "tf.Tile"(%4, %197) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %199 = "tf.Mul"(%196, %8) {device = ""} : (tensor, tensor) -> tensor - %200 = "tf.Pack"(%199) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %201 = "tf.ConcatV2"(%9, %200, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %202 = "tf.Reshape"(%198, %201) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %203 = "tf.Shape"(%202) {device = ""} : (tensor) -> tensor<1xi64> - %204 = "tf.StridedSlice"(%203, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %205 = "tf.Pack"(%194) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %206 = "tf.StridedSlice"(%202, %205, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %207 = "tf.Sub"(%204, %194) {device = ""} : (tensor, tensor) -> tensor - %208 = "tf.Pack"(%207) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %209 = "tf.StridedSlice"(%202, %11, %208, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %210:2 = "tf.RaggedRange"(%209, %206, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %211 = "tf.StridedSlice"(%193, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %212 = "tf.StridedSlice"(%193, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %213 = "tf.Mul"(%212, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %214 = "tf.Tile"(%213, %211) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %215 = "tf.Cumsum"(%214, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %216 = "tf.ConcatV2"(%11, %215, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %217 = "tf.StridedSlice"(%216, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %218 = "tf.ExpandDims"(%217, %7) {device = ""} : (tensor, tensor) -> tensor - %219 = "tf.Shape"(%217) {device = ""} : (tensor) -> tensor<1xi32> - %220 = "tf.StridedSlice"(%219, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %221 = "tf.Pack"(%220) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %222 = "tf.StridedSlice"(%216, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %223 = "tf.ExpandDims"(%222, %7) {device = ""} : (tensor, tensor) -> tensor - %224 = "tf.Shape"(%222) {device = ""} : (tensor) -> tensor<1xi32> - %225 = "tf.StridedSlice"(%224, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %226 = "tf.Pack"(%225) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %227 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %228 = "tf.Select"(%227, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %229 = "tf.Cast"(%228) {Truncate = false, device = ""} : (tensor) -> tensor - %230 = "tf.Reshape"(%229, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %231 = "tf.Pack"(%7, %230) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %232 = "tf.Mul"(%230, %8) {device = ""} : (tensor, tensor) -> tensor - %233 = "tf.Pack"(%232) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %234 = "tf.ConcatV2"(%9, %233, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %235 = "tf.Pack"(%228) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %236 = "tf.Pack"(%10, %103) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %237 = "tf.ExpandDims"(%236, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %238 = "tf.Tile"(%237, %231) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %239 = "tf.Reshape"(%238, %234) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %240 = "tf.Shape"(%239) {device = ""} : (tensor) -> tensor<1xi64> - %241 = "tf.StridedSlice"(%240, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %242 = "tf.Sub"(%241, %228) {device = ""} : (tensor, tensor) -> tensor - %243 = "tf.Pack"(%242) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %244 = "tf.StridedSlice"(%239, %11, %243, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %245 = "tf.StridedSlice"(%239, %235, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %246:2 = "tf.RaggedRange"(%244, %245, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %247 = "tf.GatherV2"(%110, %246#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %248 = "tf.Cast"(%247) {Truncate = false, device = ""} : (tensor) -> tensor - %249 = "tf.BroadcastTo"(%248, %221) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %250 = "tf.Max"(%249, %15) {device = "", 
keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %251 = "tf.Maximum"(%14, %250) {device = ""} : (tensor, tensor) -> tensor - %252 = "tf.Range"(%14, %251, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %253 = "tf.Pack"(%7, %251) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %254 = "tf.Tile"(%218, %253) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %255 = "tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> - %256 = "tf.StridedSlice"(%255, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %257 = "tf.Prod"(%256, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %258 = "tf.Pack"(%257) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %259 = "tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> - %260 = "tf.StridedSlice"(%259, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %261 = "tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> - %262 = "tf.StridedSlice"(%261, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %263 = "tf.ConcatV2"(%260, %258, %262, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %264 = "tf.Reshape"(%254, %263) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %265 = "tf.ExpandDims"(%249, %3) {device = ""} : (tensor, tensor) -> tensor - %266 = "tf.Less"(%252, %265) {device = ""} : (tensor, tensor) -> tensor - %267 = "tf.Reshape"(%266, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %268 = "tf.Where"(%267) {device = ""} : (tensor) -> tensor - %269 = "tf.Squeeze"(%268) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %270 = "tf.GatherV2"(%264, %269, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %271 = "tf.Cast"(%247) {Truncate = false, device = ""} : (tensor) -> tensor - %272 = "tf.BroadcastTo"(%271, %226) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %273 = "tf.Max"(%272, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %274 = "tf.Maximum"(%14, %273) {device = ""} : (tensor, tensor) -> tensor - %275 = "tf.Range"(%14, %274, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %276 = "tf.Pack"(%7, %274) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %277 = "tf.Tile"(%223, %276) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %278 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> - %279 = "tf.StridedSlice"(%278, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %280 = "tf.Prod"(%279, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %281 = "tf.Pack"(%280) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %282 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> - %283 = "tf.StridedSlice"(%282, %15, %15, %16) {begin_mask = 1 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %284 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> - %285 = "tf.StridedSlice"(%284, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %286 = "tf.ConcatV2"(%283, %281, %285, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %287 = "tf.Reshape"(%277, %286) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %288 = "tf.ExpandDims"(%272, %3) {device = ""} : (tensor, tensor) -> tensor - %289 = "tf.Less"(%275, %288) {device = ""} : (tensor, tensor) -> tensor - %290 = "tf.Reshape"(%289, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %291 = "tf.Where"(%290) {device = ""} : (tensor) -> tensor - %292 = "tf.Squeeze"(%291) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %293 = "tf.GatherV2"(%287, %292, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %294:2 = "tf.RaggedRange"(%270, %293, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %295 = "tf.If"(%172, %172, %168, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9750, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9740} : (tensor, tensor, tensor, tensor) -> tensor - %296 = "tf.Identity"(%295) {device = ""} : (tensor) -> tensor - %297 = "tf.Select"(%2, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %298 = "tf.Pack"(%297) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %299 = "tf.ConcatV2"(%1, %298, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %300 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %301 = "tf.Equal"(%300, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %302 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %303 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %304 = "tf.Equal"(%303, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %305 = "tf.If"(%304, %304, %303, %247) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10240, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10230} : (tensor, tensor, tensor, tensor) -> tensor - %306 = "tf.Identity"(%305) {device = ""} : (tensor) -> tensor - %307 = "tf.If"(%301, %301, %247, %302) 
{_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10600, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10590} : (tensor, tensor, tensor, tensor) -> tensor - %308 = "tf.If"(%147, %147, %13, %143) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_AssertGuard_false_15300, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_AssertGuard_true_15290} : (tensor, tensor, tensor, tensor) -> tensor - %309 = "tf.Identity"(%308) {device = ""} : (tensor) -> tensor - %310 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %311 = "tf.Select"(%310, %13, %143) {device = ""} : (tensor, tensor, tensor) -> tensor - %312 = "tf.Equal"(%311, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %313 = "tf.LogicalOr"(%312, %2) {device = ""} : (tensor, tensor) -> tensor - %314 = "tf.Equal"(%311, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %315 = "tf.LogicalOr"(%313, %314) {device = ""} : (tensor, tensor) -> tensor - %316 = "tf.Select"(%154, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %317 = "tf.Pack"(%316, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %318 = "tf.StridedSlice"(%317, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %319 = "tf.Cast"(%318) {Truncate = false, device = ""} : (tensor) -> tensor - %320 = "tf.Reshape"(%319, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %321 = "tf.Pack"(%7, %320) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %322 = "tf.Tile"(%155, %321) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %323 = "tf.Mul"(%320, %158) {device = ""} : (tensor, tensor) -> tensor - %324 = "tf.Pack"(%323) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %325 = "tf.ConcatV2"(%157, %324, %159, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %326 = "tf.Reshape"(%322, %325) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %327 = "tf.Shape"(%326) {device = ""} : (tensor) -> tensor<1xi64> - %328 = "tf.StridedSlice"(%327, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %329 = "tf.Pack"(%318) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %330 = "tf.StridedSlice"(%326, %329, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %331 = "tf.Sub"(%328, %318) {device = ""} : (tensor, tensor) -> tensor - %332 = "tf.Pack"(%331) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %333 = "tf.StridedSlice"(%326, %11, %332, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %334:2 = 
"tf.RaggedRange"(%333, %330, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %335 = "tf.GatherV2"(%161, %334#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %336 = "tf.StridedSlice"(%317, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %337 = "tf.StridedSlice"(%317, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %338 = "tf.StridedSlice"(%317, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %339 = "tf.ConcatV2"(%337, %338, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %340 = "tf.StridedSlice"(%317, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %341 = "tf.Mul"(%164, %340) {device = ""} : (tensor, tensor) -> tensor - %342 = "tf.Tile"(%341, %336) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %343 = "tf.Cumsum"(%342, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %344 = "tf.ConcatV2"(%11, %343, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %345 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi64> - %346 = "tf.StridedSlice"(%345, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %347 = "tf.Sub"(%346, %13) {device = ""} : (tensor, tensor) -> tensor - %348 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %349 = "tf.LogicalOr"(%348, %2) {device = ""} : (tensor, tensor) -> tensor - %350 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %351 = "tf.LogicalOr"(%349, %350) {device = ""} : (tensor, tensor) -> tensor - %352 = "tf.StridedSlice"(%344, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %353 = "tf.StridedSlice"(%344, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %354 = "tf.Sub"(%352, %353) {device = ""} : (tensor, tensor) -> tensor - %355 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi64> - %356 = "tf.StridedSlice"(%355, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %357 = "tf.Sub"(%356, %13) {device = ""} : (tensor, tensor) -> tensor - %358 = "tf.Equal"(%357, %13) {device = "", incompatible_shape_error = true} : 
(tensor, tensor) -> tensor - %359 = "tf.ExpandDims"(%344, %7) {device = ""} : (tensor, tensor) -> tensor - %360 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi32> - %361 = "tf.StridedSlice"(%360, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %362 = "tf.StridedSlice"(%360, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %363 = "tf.StridedSlice"(%360, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %364 = "tf.Select"(%2, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %365 = "tf.Pack"(%364, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %366 = "tf.StridedSlice"(%365, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %367 = "tf.Cast"(%366) {Truncate = false, device = ""} : (tensor) -> tensor - %368 = "tf.Reshape"(%367, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %369 = "tf.Pack"(%7, %368) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %370 = "tf.Tile"(%4, %369) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %371 = "tf.Mul"(%368, %8) {device = ""} : (tensor, tensor) -> tensor - %372 = "tf.Pack"(%371) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %373 = "tf.ConcatV2"(%9, %372, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %374 = "tf.Reshape"(%370, %373) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %375 = "tf.Shape"(%374) {device = ""} : (tensor) -> tensor<1xi64> - %376 = "tf.StridedSlice"(%375, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %377 = "tf.Pack"(%366) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %378 = "tf.StridedSlice"(%374, %377, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %379 = "tf.Sub"(%376, %366) {device = ""} : (tensor, tensor) -> tensor - %380 = "tf.Pack"(%379) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %381 = "tf.StridedSlice"(%374, %11, %380, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %382:2 = "tf.RaggedRange"(%381, %378, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %383 = "tf.GatherV2"(%11, %382#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %384 = "tf.GatherV2"(%12, %383, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %385 
= "tf.StridedSlice"(%365, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %386 = "tf.StridedSlice"(%365, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %387 = "tf.StridedSlice"(%365, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %388 = "tf.ConcatV2"(%386, %387, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %389 = "tf.Tile"(%384, %388) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %390 = "tf.StridedSlice"(%365, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %391 = "tf.Mul"(%390, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %392 = "tf.Tile"(%391, %385) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %393 = "tf.Cumsum"(%392, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %394 = "tf.ConcatV2"(%11, %393, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %395 = "tf.StridedSlice"(%394, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %396 = "tf.ExpandDims"(%395, %7) {device = ""} : (tensor, tensor) -> tensor - %397 = "tf.Shape"(%395) {device = ""} : (tensor) -> tensor<1xi32> - %398 = "tf.StridedSlice"(%397, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %399 = "tf.Pack"(%398) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %400 = "tf.StridedSlice"(%394, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %401 = "tf.ExpandDims"(%400, %7) {device = ""} : (tensor, tensor) -> tensor - %402 = "tf.Shape"(%400) {device = ""} : (tensor) -> tensor<1xi32> - %403 = "tf.StridedSlice"(%402, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %404 = "tf.Pack"(%403) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %405 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %406 = "tf.Select"(%405, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %407 = "tf.Cast"(%406) {Truncate = false, device = ""} : (tensor) -> tensor - %408 = "tf.Reshape"(%407, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %409 = "tf.Pack"(%7, %408) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %410 = "tf.Mul"(%408, %8) 
{device = ""} : (tensor, tensor) -> tensor - %411 = "tf.Pack"(%410) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %412 = "tf.ConcatV2"(%9, %411, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %413 = "tf.Pack"(%406) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %414 = "tf.Pack"(%10, %143) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %415 = "tf.ExpandDims"(%414, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %416 = "tf.Tile"(%415, %409) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %417 = "tf.Reshape"(%416, %412) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %418 = "tf.Shape"(%417) {device = ""} : (tensor) -> tensor<1xi64> - %419 = "tf.StridedSlice"(%418, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %420 = "tf.Sub"(%419, %406) {device = ""} : (tensor, tensor) -> tensor - %421 = "tf.Pack"(%420) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %422 = "tf.StridedSlice"(%417, %11, %421, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %423 = "tf.StridedSlice"(%417, %413, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %424:2 = "tf.RaggedRange"(%422, %423, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %425 = "tf.GatherV2"(%150, %424#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %426 = "tf.Cast"(%425) {Truncate = false, device = ""} : (tensor) -> tensor - %427 = "tf.BroadcastTo"(%426, %399) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %428 = "tf.Max"(%427, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %429 = "tf.Maximum"(%14, %428) {device = ""} : (tensor, tensor) -> tensor - %430 = "tf.Range"(%14, %429, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %431 = "tf.Pack"(%7, %429) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %432 = "tf.Tile"(%396, %431) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %433 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> - %434 = "tf.StridedSlice"(%433, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %435 = "tf.Prod"(%434, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %436 = "tf.Pack"(%435) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %437 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> - %438 = "tf.StridedSlice"(%437, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %439 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> - %440 = "tf.StridedSlice"(%439, %6, %15, %16) {begin_mask = 0 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %441 = "tf.ConcatV2"(%438, %436, %440, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %442 = "tf.Reshape"(%432, %441) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %443 = "tf.ExpandDims"(%427, %3) {device = ""} : (tensor, tensor) -> tensor - %444 = "tf.Less"(%430, %443) {device = ""} : (tensor, tensor) -> tensor - %445 = "tf.Reshape"(%444, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %446 = "tf.Where"(%445) {device = ""} : (tensor) -> tensor - %447 = "tf.Squeeze"(%446) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %448 = "tf.GatherV2"(%442, %447, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %449 = "tf.Cast"(%425) {Truncate = false, device = ""} : (tensor) -> tensor - %450 = "tf.BroadcastTo"(%449, %404) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %451 = "tf.Max"(%450, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %452 = "tf.Maximum"(%14, %451) {device = ""} : (tensor, tensor) -> tensor - %453 = "tf.Range"(%14, %452, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %454 = "tf.Pack"(%7, %452) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %455 = "tf.Tile"(%401, %454) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %456 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> - %457 = "tf.StridedSlice"(%456, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %458 = "tf.Prod"(%457, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %459 = "tf.Pack"(%458) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %460 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> - %461 = "tf.StridedSlice"(%460, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %462 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> - %463 = "tf.StridedSlice"(%462, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %464 = "tf.ConcatV2"(%461, %459, %463, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %465 = "tf.Reshape"(%455, %464) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %466 = "tf.ExpandDims"(%450, %3) {device = ""} : (tensor, tensor) -> tensor - %467 = "tf.Less"(%453, %466) {device = ""} : (tensor, tensor) -> tensor - %468 = "tf.Reshape"(%467, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %469 = "tf.Where"(%468) {device = ""} : (tensor) -> tensor - %470 = "tf.Squeeze"(%469) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %471 = "tf.GatherV2"(%465, %470, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %472:2 = "tf.RaggedRange"(%448, %471, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %473 = "tf.GatherV2"(%389, %472#1, %14) 
{batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %474 = "tf.If"(%315, %315, %311, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_1_AssertGuard_false_16370, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_1_AssertGuard_true_16360} : (tensor, tensor, tensor, tensor) -> tensor - %475 = "tf.Identity"(%474) {device = ""} : (tensor) -> tensor - %476 = "tf.Select"(%2, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %477 = "tf.Pack"(%476) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %478 = "tf.ConcatV2"(%1, %477, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %479 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %480 = "tf.Equal"(%479, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %481 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %482 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %483 = "tf.Equal"(%482, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %484 = "tf.If"(%483, %483, %482, %425) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_2_AssertGuard_false_16860, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_2_AssertGuard_true_16850} : (tensor, tensor, tensor, tensor) -> tensor - %485 = "tf.Identity"(%484) {device = ""} : (tensor) -> tensor - %486 = "tf.If"(%480, %480, %425, %481) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_3_AssertGuard_false_17220, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_3_AssertGuard_true_17210} : (tensor, tensor, tensor, tensor) -> tensor - %487 = "tf.Identity"(%486) {device = ""} : (tensor) -> tensor - %488 = "tf.If"(%351, %351, %13, %347) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21900, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21890} : (tensor, tensor, tensor, tensor) -> tensor - %489 = "tf.Identity"(%488) {device = ""} : (tensor) -> tensor - %490 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %491 = "tf.Select"(%490, %13, %347) {device = ""} : (tensor, tensor, tensor) -> tensor - %492 = "tf.Equal"(%491, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %493 = "tf.LogicalOr"(%492, %2) {device = ""} : (tensor, tensor) -> tensor - %494 = "tf.Equal"(%491, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %495 = 
"tf.LogicalOr"(%493, %494) {device = ""} : (tensor, tensor) -> tensor - %496 = "tf.Select"(%358, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %497 = "tf.Pack"(%496, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %498 = "tf.StridedSlice"(%497, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %499 = "tf.Cast"(%498) {Truncate = false, device = ""} : (tensor) -> tensor - %500 = "tf.Reshape"(%499, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %501 = "tf.Pack"(%7, %500) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %502 = "tf.Tile"(%359, %501) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %503 = "tf.Mul"(%500, %362) {device = ""} : (tensor, tensor) -> tensor - %504 = "tf.Pack"(%503) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %505 = "tf.ConcatV2"(%361, %504, %363, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %506 = "tf.Reshape"(%502, %505) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %507 = "tf.Shape"(%506) {device = ""} : (tensor) -> tensor<1xi64> - %508 = "tf.StridedSlice"(%507, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %509 = "tf.Pack"(%498) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %510 = "tf.StridedSlice"(%506, %509, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %511 = "tf.Sub"(%508, %498) {device = ""} : (tensor, tensor) -> tensor - %512 = "tf.Pack"(%511) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %513 = "tf.StridedSlice"(%506, %11, %512, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %514:2 = "tf.RaggedRange"(%513, %510, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %515 = "tf.Select"(%2, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %516 = "tf.Pack"(%515, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %517 = "tf.StridedSlice"(%516, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %518 = "tf.Cast"(%517) {Truncate = false, device = ""} : (tensor) -> tensor - %519 = "tf.Reshape"(%518, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %520 = "tf.Pack"(%7, %519) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %521 = "tf.Tile"(%4, %520) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %522 = "tf.Mul"(%519, %8) {device = ""} : (tensor, tensor) -> tensor - %523 = "tf.Pack"(%522) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %524 = "tf.ConcatV2"(%9, %523, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %525 = "tf.Reshape"(%521, %524) {device = ""} : (tensor<2x?xi64>, 
tensor<1xi32>) -> tensor - %526 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<1xi64> - %527 = "tf.StridedSlice"(%526, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %528 = "tf.Pack"(%517) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %529 = "tf.StridedSlice"(%525, %528, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %530 = "tf.Sub"(%527, %517) {device = ""} : (tensor, tensor) -> tensor - %531 = "tf.Pack"(%530) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %532 = "tf.StridedSlice"(%525, %11, %531, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %533:2 = "tf.RaggedRange"(%532, %529, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %534 = "tf.StridedSlice"(%516, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %535 = "tf.StridedSlice"(%516, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %536 = "tf.Mul"(%535, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %537 = "tf.Tile"(%536, %534) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %538 = "tf.Cumsum"(%537, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %539 = "tf.ConcatV2"(%11, %538, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %540 = "tf.StridedSlice"(%539, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %541 = "tf.ExpandDims"(%540, %7) {device = ""} : (tensor, tensor) -> tensor - %542 = "tf.Shape"(%540) {device = ""} : (tensor) -> tensor<1xi32> - %543 = "tf.StridedSlice"(%542, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %544 = "tf.Pack"(%543) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %545 = "tf.StridedSlice"(%539, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %546 = "tf.ExpandDims"(%545, %7) {device = ""} : (tensor, tensor) -> tensor - %547 = "tf.Shape"(%545) {device = ""} : (tensor) -> tensor<1xi32> - %548 = "tf.StridedSlice"(%547, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %549 = "tf.Pack"(%548) {axis = 0 : i64, device = ""} : 
(tensor) -> tensor<1xi32> - %550 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %551 = "tf.Select"(%550, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %552 = "tf.Cast"(%551) {Truncate = false, device = ""} : (tensor) -> tensor - %553 = "tf.Reshape"(%552, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %554 = "tf.Pack"(%7, %553) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %555 = "tf.Mul"(%553, %8) {device = ""} : (tensor, tensor) -> tensor - %556 = "tf.Pack"(%555) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %557 = "tf.ConcatV2"(%9, %556, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %558 = "tf.Pack"(%551) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %559 = "tf.Pack"(%10, %347) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %560 = "tf.ExpandDims"(%559, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %561 = "tf.Tile"(%560, %554) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %562 = "tf.Reshape"(%561, %557) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %563 = "tf.Shape"(%562) {device = ""} : (tensor) -> tensor<1xi64> - %564 = "tf.StridedSlice"(%563, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %565 = "tf.Sub"(%564, %551) {device = ""} : (tensor, tensor) -> tensor - %566 = "tf.Pack"(%565) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %567 = "tf.StridedSlice"(%562, %11, %566, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %568 = "tf.StridedSlice"(%562, %558, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %569:2 = "tf.RaggedRange"(%567, %568, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %570 = "tf.GatherV2"(%354, %569#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %571 = "tf.Cast"(%570) {Truncate = false, device = ""} : (tensor) -> tensor - %572 = "tf.BroadcastTo"(%571, %544) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %573 = "tf.Max"(%572, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %574 = "tf.Maximum"(%14, %573) {device = ""} : (tensor, tensor) -> tensor - %575 = "tf.Range"(%14, %574, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %576 = "tf.Pack"(%7, %574) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %577 = "tf.Tile"(%541, %576) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %578 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> - %579 = "tf.StridedSlice"(%578, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %580 = "tf.Prod"(%579, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %581 = "tf.Pack"(%580) {axis = 0 : i64, device = ""} : (tensor) -> 
tensor<1xi32> - %582 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> - %583 = "tf.StridedSlice"(%582, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %584 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> - %585 = "tf.StridedSlice"(%584, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %586 = "tf.ConcatV2"(%583, %581, %585, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %587 = "tf.Reshape"(%577, %586) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %588 = "tf.ExpandDims"(%572, %3) {device = ""} : (tensor, tensor) -> tensor - %589 = "tf.Less"(%575, %588) {device = ""} : (tensor, tensor) -> tensor - %590 = "tf.Reshape"(%589, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %591 = "tf.Where"(%590) {device = ""} : (tensor) -> tensor - %592 = "tf.Squeeze"(%591) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %593 = "tf.GatherV2"(%587, %592, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %594 = "tf.Cast"(%570) {Truncate = false, device = ""} : (tensor) -> tensor - %595 = "tf.BroadcastTo"(%594, %549) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %596 = "tf.Max"(%595, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %597 = "tf.Maximum"(%14, %596) {device = ""} : (tensor, tensor) -> tensor - %598 = "tf.Range"(%14, %597, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %599 = "tf.Pack"(%7, %597) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %600 = "tf.Tile"(%546, %599) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %601 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> - %602 = "tf.StridedSlice"(%601, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %603 = "tf.Prod"(%602, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %604 = "tf.Pack"(%603) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %605 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> - %606 = "tf.StridedSlice"(%605, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %607 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> - %608 = "tf.StridedSlice"(%607, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %609 = "tf.ConcatV2"(%606, %604, %608, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %610 = "tf.Reshape"(%600, %609) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %611 = "tf.ExpandDims"(%595, %3) {device = ""} : (tensor, tensor) -> tensor - %612 = "tf.Less"(%598, %611) {device = ""} : (tensor, tensor) -> tensor - %613 = 
"tf.Reshape"(%612, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %614 = "tf.Where"(%613) {device = ""} : (tensor) -> tensor - %615 = "tf.Squeeze"(%614) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %616 = "tf.GatherV2"(%610, %615, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %617:2 = "tf.RaggedRange"(%593, %616, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %618 = "tf.If"(%495, %495, %491, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22960} : (tensor, tensor, tensor, tensor) -> tensor - %619 = "tf.Identity"(%618) {device = ""} : (tensor) -> tensor - %620 = "tf.Select"(%2, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %621 = "tf.Pack"(%620) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %622 = "tf.ConcatV2"(%1, %621, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %623 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %624 = "tf.Equal"(%623, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %625 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %626 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %627 = "tf.Equal"(%626, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %628 = "tf.If"(%627, %627, %626, %570) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23460, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23450} : (tensor, tensor, tensor, tensor) -> tensor - %629 = "tf.Identity"(%628) {device = ""} : (tensor) -> tensor - %630 = "tf.If"(%624, %624, %570, %625) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23820, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810} : (tensor, tensor, tensor, tensor) -> tensor - %631 = "tf.Identity"(%79) {device = ""} : (tensor) -> tensor - %632 = "tf.Identity"(%630) {device = ""} : (tensor) -> tensor - %633 = "tf.Identity"(%307) {device = ""} : (tensor) -> tensor - %634 = "tf.Shape"(%36#2) {device = ""} : (tensor) -> tensor<1xi32> - %635 = "tf.StridedSlice"(%634, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %636 = 
"tf.Cast"(%635) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %637 = "tf.Identity"(%636) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %638 = "tf.Shape"(%36#3) {device = ""} : (tensor) -> tensor<1xi32> - %639 = "tf.StridedSlice"(%638, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %640 = "tf.Cast"(%639) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %641 = "tf.Identity"(%640) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %642 = "tf.GatherV2"(%36#3, %335, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %643 = "tf.Tile"(%642, %339) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %644 = "tf.Sub"(%643, %473) {device = ""} : (tensor, tensor) -> tensor - %645 = "tf.Shape"(%644) {device = ""} : (tensor) -> tensor<1xi32> - %646 = "tf.StridedSlice"(%645, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %647 = "tf.Cast"(%646) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %648 = "tf.Identity"(%647) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %649 = "tf.UnicodeEncode"(%36#0, %57) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor - %650 = "tf.Identity"(%649) {device = ""} : (tensor) -> tensor - return %650, %631 : tensor, tensor - } - func @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedConcat/RaggedFromTensor/Const:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedConcat/RaggedNRows/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3200(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y 
did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3960(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4330(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4320(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = 
"tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4660(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5040(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5400(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5390(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func 
@WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5760(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6470(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = 
"tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6460(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6800(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : 
(tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7170(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7540(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7530(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7880(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> 
tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9740(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func 
@WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10590(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_Assert_AssertGuard_false_15300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_Assert_AssertGuard_true_15290(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_Assert_1_AssertGuard_false_16370(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_Assert_1_AssertGuard_true_16360(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - 
func @WhitespaceTokenize_Assert_2_AssertGuard_false_16860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_Assert_2_AssertGuard_true_16850(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_Assert_3_AssertGuard_false_17220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_Assert_3_AssertGuard_true_17210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21900(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21890(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - 
} - func @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22960(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23450(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23820(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = 
"tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - - // CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor, tensor) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - // CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor, tensor) - // CHECK: return %0#0, %0#1 : tensor, tensor - - func @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {tf._input_shapes = [#tf.shape], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> - %1 = "tf.Const"() {value = dense : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> - %4 = "tf.Const"() {value = dense<[2, -1]> : tensor<2xi32>} : () -> tensor<2xi32> - %5 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - %6 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %7 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %8 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - %9 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %10 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - %11 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %12 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %13 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> - %14 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> - %15 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %16 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %17 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %18 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %19 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<2xi64> - %20 = "tf.StridedSlice"(%19, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %21 = "tf.StridedSlice"(%19, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %22 = "tf.Mul"(%20, %21) {device = ""} : (tensor, tensor) -> tensor - %23 = "tf.Pack"(%22) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %24 = "tf.StridedSlice"(%19, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %25 = "tf.ConcatV2"(%23, %24, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %26 = "tf.Reshape"(%arg0, %25) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %27 = "tf.StringLength"(%26) {device = "", unit = "BYTE"} : (tensor) -> tensor - %28 = "tf.ExpandDims"(%27, %9) {device = ""} : (tensor, tensor) -> tensor - %29 = "tf.Cast"(%28) {Truncate = 
false, device = ""} : (tensor) -> tensor - %30 = "tf.Shape"(%29) {device = ""} : (tensor) -> tensor<2xi64> - %31 = "tf.StridedSlice"(%30, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %32 = "tf.StridedSlice"(%30, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %33 = "tf.Mul"(%31, %32) {device = ""} : (tensor, tensor) -> tensor - %34 = "tf.Pack"(%33) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %35 = "tf.StridedSlice"(%30, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %36 = "tf.ConcatV2"(%34, %35, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %37 = "tf.Reshape"(%29, %36) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %38 = "tf.StridedSlice"(%30, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %39 = "tf.AddV2"(%38, %15) {device = ""} : (tensor, tensor) -> tensor - %40 = "tf.Range"(%12, %39, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %41 = "tf.Mul"(%40, %15) {device = ""} : (tensor, tensor) -> tensor - %42 = "tf.Reshape"(%26, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %43:3 = "tf.UnicodeDecodeWithOffsets"(%42) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor) -> (tensor, tensor, tensor) - %44 = "tf.StridedSlice"(%43#0, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %45 = "tf.Shape"(%44) {device = ""} : (tensor) -> tensor<1xi32> - %46 = "tf.ConcatV2"(%45, %18, %16) {device = ""} : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> - %47 = "tf.Reshape"(%44, %46) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %48 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi64> - %49 = "tf.StridedSlice"(%48, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %50 = "tf.AddV2"(%49, %15) {device = ""} : (tensor, tensor) -> tensor - %51 = "tf.Range"(%12, %50, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %52 = "tf.Mul"(%51, %15) {device = ""} : (tensor, tensor) -> tensor - %53 = "tf.ExpandDims"(%52, %9) {device = ""} : (tensor, tensor) -> tensor - %54 = "tf.Shape"(%52) {device = ""} : (tensor) -> tensor<1xi32> - %55 = "tf.StridedSlice"(%54, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %56 = "tf.StridedSlice"(%54, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : 
i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %57 = "tf.StridedSlice"(%54, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %58 = "tf.StridedSlice"(%52, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %59 = "tf.StridedSlice"(%52, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %60 = "tf.Sub"(%58, %59) {device = ""} : (tensor, tensor) -> tensor - %61 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi32> - %62 = "tf.Cast"(%61) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> - %63 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %64 = "tf.Equal"(%63, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %65 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %66 = "tf.Equal"(%65, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %67 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %68 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi32> - %69 = "tf.Cast"(%68) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> - %70 = "tf.StridedSlice"(%69, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %71 = "tf.Equal"(%70, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %72 = "tf.StridedSlice"(%43#0, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %73 = "tf.AddV2"(%72, %15) {device = ""} : (tensor, tensor) -> tensor - %74 = "tf.StridedSlice"(%43#0, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %75 = "tf.Minimum"(%73, %74) {device = ""} : (tensor, tensor) -> tensor - %76:2 = "tf.RaggedRange"(%75, %74, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %77 = "tf.Shape"(%76#0) {device = ""} : (tensor) -> tensor<1xi64> - %78 = "tf.StridedSlice"(%77, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, 
end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %79 = "tf.Sub"(%78, %15) {device = ""} : (tensor, tensor) -> tensor - %80 = "tf.Equal"(%38, %79) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %81 = "tf.All"(%80, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %82 = "tf.If"(%81, %81, %38, %79) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_99640, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_99630} : (tensor, tensor, tensor, tensor) -> tensor - %83 = "tf.Identity"(%82) {device = ""} : (tensor) -> tensor - %84 = "tf.StridedSlice"(%41, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %85 = "tf.Mul"(%79, %5) {device = ""} : (tensor, tensor) -> tensor - %86 = "tf.Range"(%12, %85, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %87 = "tf.Reshape"(%86, %4) {device = ""} : (tensor, tensor<2xi32>) -> tensor<2x?xi64> - %88 = "tf.Transpose"(%87, %8) {device = ""} : (tensor<2x?xi64>, tensor<2xi32>) -> tensor - %89 = "tf.Reshape"(%88, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %90 = "tf.StridedSlice"(%76#0, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %91 = "tf.AddV2"(%84, %90) {device = ""} : (tensor, tensor) -> tensor - %92 = "tf.ConcatV2"(%76#0, %91, %16) {device = ""} : (tensor, tensor, tensor) -> tensor - %93 = "tf.GatherV2"(%43#2, %76#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %94 = "tf.ConcatV2"(%93, %37, %16) {device = ""} : (tensor, tensor, tensor) -> tensor - %95:2 = "tf.RaggedGather"(%92, %94, %89) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %96 = "tf.StridedSlice"(%95#0, %17, %17, %7) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %97 = "tf.StridedSlice"(%96, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %98 = "tf.Shape"(%97) {device = ""} : (tensor) -> tensor<1xi32> - %99 = "tf.ConcatV2"(%98, %18, %16) {device = ""} : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> - %100 = "tf.Reshape"(%97, %99) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %101 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi64> - %102 = "tf.StridedSlice"(%101, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %103 = "tf.AddV2"(%102, %15) {device = ""} 
: (tensor, tensor) -> tensor - %104 = "tf.Range"(%12, %103, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %105 = "tf.Mul"(%104, %15) {device = ""} : (tensor, tensor) -> tensor - %106 = "tf.ExpandDims"(%105, %9) {device = ""} : (tensor, tensor) -> tensor - %107 = "tf.Shape"(%105) {device = ""} : (tensor) -> tensor<1xi32> - %108 = "tf.StridedSlice"(%107, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %109 = "tf.StridedSlice"(%107, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %110 = "tf.StridedSlice"(%107, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %111 = "tf.StridedSlice"(%105, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %112 = "tf.StridedSlice"(%105, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %113 = "tf.Sub"(%111, %112) {device = ""} : (tensor, tensor) -> tensor - %114 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi32> - %115 = "tf.Cast"(%114) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> - %116 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %117 = "tf.Equal"(%116, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %118 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %119 = "tf.Equal"(%118, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %120 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %121 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi32> - %122 = "tf.Cast"(%121) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> - %123 = "tf.StridedSlice"(%122, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %124 = "tf.Equal"(%123, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %125:5 = "tf.WhitespaceTokenizeWithOffsets"(%43#1, %43#0) {Tsplits = i64, device = ""} : (tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor) - %126 = "tf.StridedSlice"(%125#1, %17, %18, %18) 
{begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %127 = "tf.Equal"(%126, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %128 = "tf.All"(%127, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %129 = "tf.If"(%128, %128, %126, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_100400, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_100390} : (tensor, tensor, tensor, tensor) -> tensor - %130 = "tf.Identity"(%129) {device = ""} : (tensor) -> tensor - %131 = "tf.StridedSlice"(%125#1, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %132 = "tf.StridedSlice"(%125#1, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %133 = "tf.Sub"(%131, %132) {device = ""} : (tensor, tensor) -> tensor - %134 = "tf.LessEqual"(%12, %133) {device = ""} : (tensor, tensor) -> tensor - %135 = "tf.All"(%134, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %136 = "tf.If"(%135, %135, %133) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_100760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_100750} : (tensor, tensor, tensor) -> tensor - %137 = "tf.Identity"(%136) {device = ""} : (tensor) -> tensor - %138 = "tf.Identity"(%125#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %139 = "tf.StridedSlice"(%138, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %140 = "tf.Shape"(%125#0) {device = ""} : (tensor) -> tensor<1xi64> - %141 = "tf.StridedSlice"(%140, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %142 = "tf.Equal"(%139, %141) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %143 = "tf.All"(%142, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %144 = "tf.If"(%143, %143, %139, %141) {_lower_using_switch_merge = true, 
_read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101100, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101090} : (tensor, tensor, tensor, tensor) -> tensor - %145 = "tf.Identity"(%144) {device = ""} : (tensor) -> tensor - %146 = "tf.Identity"(%138) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %147 = "tf.Shape"(%146) {device = ""} : (tensor) -> tensor<1xi64> - %148 = "tf.StridedSlice"(%147, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %149 = "tf.Sub"(%148, %15) {device = ""} : (tensor, tensor) -> tensor - %150 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %151 = "tf.Equal"(%150, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %152 = "tf.All"(%151, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %153 = "tf.If"(%152, %152, %150, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101460} : (tensor, tensor, tensor, tensor) -> tensor - %154 = "tf.Identity"(%153) {device = ""} : (tensor) -> tensor - %155 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %156 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %157 = "tf.Sub"(%155, %156) {device = ""} : (tensor, tensor) -> tensor - %158 = "tf.LessEqual"(%12, %157) {device = ""} : (tensor, tensor) -> tensor - %159 = "tf.All"(%158, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %160 = "tf.If"(%159, %159, %157) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_101830, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_101820} : (tensor, tensor, tensor) -> tensor - %161 = "tf.Identity"(%160) {device = ""} : (tensor) -> tensor - %162 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %163 = "tf.StridedSlice"(%162, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %164 = "tf.Equal"(%163, %149) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %165 = "tf.All"(%164, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %166 = "tf.If"(%165, %165, %163, %149) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_102190, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_102180} : (tensor, tensor, tensor, tensor) -> tensor - %167 = "tf.Identity"(%166) {device = ""} : (tensor) -> tensor - %168 = "tf.Identity"(%162) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %169 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %170 = "tf.Equal"(%169, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %171 = "tf.All"(%170, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %172 = "tf.If"(%171, %171, %169, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_102540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_102530} : (tensor, tensor, tensor, tensor) -> tensor - %173 = "tf.Identity"(%172) {device = ""} : (tensor) -> tensor - %174 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %175 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %176 = "tf.Sub"(%174, %175) {device = ""} : (tensor, tensor) -> tensor - %177 = "tf.LessEqual"(%12, %176) {device = ""} : (tensor, tensor) -> tensor - %178 = "tf.All"(%177, %17) {device = "", keep_dims = false} 
: (tensor, tensor<1xi32>) -> tensor - %179 = "tf.If"(%178, %178, %176) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_102900, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_102890} : (tensor, tensor, tensor) -> tensor - %180 = "tf.Identity"(%179) {device = ""} : (tensor) -> tensor - %181 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %182 = "tf.StridedSlice"(%181, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %183 = "tf.Shape"(%125#2) {device = ""} : (tensor) -> tensor<1xi64> - %184 = "tf.StridedSlice"(%183, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %185 = "tf.Equal"(%182, %184) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %186 = "tf.All"(%185, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %187 = "tf.If"(%186, %186, %182, %184) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103240, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103230} : (tensor, tensor, tensor, tensor) -> tensor - %188 = "tf.Identity"(%187) {device = ""} : (tensor) -> tensor - %189 = "tf.Identity"(%181) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %190 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi64> - %191 = "tf.StridedSlice"(%190, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %192 = "tf.Sub"(%191, %15) {device = ""} : (tensor, tensor) -> tensor - %193 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %194 = "tf.LogicalOr"(%64, %193) {device = ""} : (tensor, tensor) -> tensor - %195 = "tf.Equal"(%192, %63) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %196 = "tf.LogicalOr"(%194, %195) {device = ""} : (tensor, tensor) -> tensor - %197 = "tf.StridedSlice"(%189, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %198 = "tf.StridedSlice"(%189, %17, %6, %18) {begin_mask = 1 
: i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %199 = "tf.Sub"(%197, %198) {device = ""} : (tensor, tensor) -> tensor - %200 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi64> - %201 = "tf.StridedSlice"(%200, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %202 = "tf.Sub"(%201, %15) {device = ""} : (tensor, tensor) -> tensor - %203 = "tf.Equal"(%202, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %204 = "tf.ExpandDims"(%189, %9) {device = ""} : (tensor, tensor) -> tensor - %205 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi32> - %206 = "tf.StridedSlice"(%205, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %207 = "tf.StridedSlice"(%205, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %208 = "tf.StridedSlice"(%205, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %209 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %210 = "tf.Equal"(%209, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %211 = "tf.All"(%210, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %212 = "tf.If"(%211, %211, %209, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103610, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103600} : (tensor, tensor, tensor, tensor) -> tensor - %213 = "tf.Identity"(%212) {device = ""} : (tensor) -> tensor - %214 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %215 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %216 = "tf.Sub"(%214, %215) {device = ""} : (tensor, tensor) -> tensor - %217 = "tf.LessEqual"(%12, %216) {device = ""} : (tensor, tensor) -> tensor - %218 = "tf.All"(%217, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> 
tensor - %219 = "tf.If"(%218, %218, %216) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_103970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_103960} : (tensor, tensor, tensor) -> tensor - %220 = "tf.Identity"(%219) {device = ""} : (tensor) -> tensor - %221 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %222 = "tf.StridedSlice"(%221, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %223 = "tf.Shape"(%125#3) {device = ""} : (tensor) -> tensor<1xi64> - %224 = "tf.StridedSlice"(%223, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %225 = "tf.Equal"(%222, %224) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %226 = "tf.All"(%225, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %227 = "tf.If"(%226, %226, %222, %224) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_104310, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_104300} : (tensor, tensor, tensor, tensor) -> tensor - %228 = "tf.Identity"(%227) {device = ""} : (tensor) -> tensor - %229 = "tf.Identity"(%221) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %230 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi64> - %231 = "tf.StridedSlice"(%230, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %232 = "tf.Sub"(%231, %15) {device = ""} : (tensor, tensor) -> tensor - %233 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %234 = "tf.LogicalOr"(%233, %1) {device = ""} : (tensor, tensor) -> tensor - %235 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %236 = "tf.LogicalOr"(%234, %235) {device = ""} : (tensor, tensor) -> tensor - %237 = "tf.StridedSlice"(%229, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %238 = "tf.StridedSlice"(%229, %17, %6, %18) {begin_mask = 1 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %239 = "tf.Sub"(%237, %238) {device = ""} : (tensor, tensor) -> tensor - %240 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi64> - %241 = "tf.StridedSlice"(%240, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %242 = "tf.Sub"(%241, %15) {device = ""} : (tensor, tensor) -> tensor - %243 = "tf.Equal"(%242, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %244 = "tf.ExpandDims"(%229, %9) {device = ""} : (tensor, tensor) -> tensor - %245 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi32> - %246 = "tf.StridedSlice"(%245, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %247 = "tf.StridedSlice"(%245, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %248 = "tf.StridedSlice"(%245, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %249 = "tf.StridedSlice"(%229, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %250 = "tf.Range"(%12, %249, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %251 = "tf.StridedSlice"(%229, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %252 = "tf.StridedSlice"(%229, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %253 = "tf.Sub"(%251, %252) {device = ""} : (tensor, tensor) -> tensor - %254 = "tf.If"(%196, %196, %63, %192) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_105110, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_105100} : (tensor, tensor, tensor, tensor) -> tensor - %255 = "tf.Identity"(%254) {device = ""} : (tensor) -> tensor - %256 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %257 = "tf.Select"(%256, %63, %192) {device = ""} : (tensor, tensor, tensor) -> tensor - %258 = "tf.Equal"(%257, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %259 = "tf.LogicalOr"(%258, %66) {device = ""} : (tensor, tensor) -> tensor - %260 = "tf.Equal"(%65, %257) {device = "", incompatible_shape_error = true} : (tensor, 
tensor) -> tensor - %261 = "tf.LogicalOr"(%259, %260) {device = ""} : (tensor, tensor) -> tensor - %262 = "tf.Select"(%203, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %263 = "tf.Pack"(%262, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %264 = "tf.StridedSlice"(%263, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %265 = "tf.Cast"(%264) {Truncate = false, device = ""} : (tensor) -> tensor - %266 = "tf.Reshape"(%265, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %267 = "tf.Pack"(%9, %266) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %268 = "tf.Tile"(%204, %267) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %269 = "tf.Mul"(%266, %207) {device = ""} : (tensor, tensor) -> tensor - %270 = "tf.Pack"(%269) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %271 = "tf.ConcatV2"(%206, %270, %208, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %272 = "tf.Reshape"(%268, %271) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %273 = "tf.Shape"(%272) {device = ""} : (tensor) -> tensor<1xi64> - %274 = "tf.StridedSlice"(%273, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %275 = "tf.Pack"(%264) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %276 = "tf.StridedSlice"(%272, %275, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %277 = "tf.Sub"(%274, %264) {device = ""} : (tensor, tensor) -> tensor - %278 = "tf.Pack"(%277) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %279 = "tf.StridedSlice"(%272, %13, %278, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %280:2 = "tf.RaggedRange"(%279, %276, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %281 = "tf.Select"(%71, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %282 = "tf.Pack"(%281, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %283 = "tf.StridedSlice"(%282, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %284 = "tf.Cast"(%283) {Truncate = false, device = ""} : (tensor) -> tensor - %285 = "tf.Reshape"(%284, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %286 = "tf.Pack"(%9, %285) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %287 = "tf.Tile"(%53, %286) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %288 = "tf.Mul"(%285, %56) {device = ""} : (tensor, tensor) -> tensor - %289 = "tf.Pack"(%288) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %290 = "tf.ConcatV2"(%55, %289, %57, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %291 = "tf.Reshape"(%287, %290) {device = ""} : (tensor, 
tensor<1xi32>) -> tensor - %292 = "tf.Shape"(%291) {device = ""} : (tensor) -> tensor<1xi64> - %293 = "tf.StridedSlice"(%292, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %294 = "tf.Pack"(%283) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %295 = "tf.StridedSlice"(%291, %294, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %296 = "tf.Sub"(%293, %283) {device = ""} : (tensor, tensor) -> tensor - %297 = "tf.Pack"(%296) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %298 = "tf.StridedSlice"(%291, %13, %297, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %299:2 = "tf.RaggedRange"(%298, %295, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %300 = "tf.StridedSlice"(%282, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %301 = "tf.StridedSlice"(%282, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %302 = "tf.Mul"(%60, %301) {device = ""} : (tensor, tensor) -> tensor - %303 = "tf.Tile"(%302, %300) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %304 = "tf.Cumsum"(%303, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %305 = "tf.ConcatV2"(%13, %304, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %306 = "tf.StridedSlice"(%305, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %307 = "tf.ExpandDims"(%306, %9) {device = ""} : (tensor, tensor) -> tensor - %308 = "tf.Shape"(%306) {device = ""} : (tensor) -> tensor<1xi32> - %309 = "tf.StridedSlice"(%308, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %310 = "tf.Pack"(%309) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %311 = "tf.StridedSlice"(%305, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %312 = "tf.ExpandDims"(%311, %9) {device = ""} : (tensor, tensor) -> tensor - %313 = "tf.Shape"(%311) {device = ""} : (tensor) -> tensor<1xi32> - %314 = "tf.StridedSlice"(%313, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %315 = "tf.Pack"(%314) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> 
- %316 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %317 = "tf.Select"(%316, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %318 = "tf.Cast"(%317) {Truncate = false, device = ""} : (tensor) -> tensor - %319 = "tf.Reshape"(%318, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %320 = "tf.Pack"(%9, %319) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %321 = "tf.Mul"(%319, %10) {device = ""} : (tensor, tensor) -> tensor - %322 = "tf.Pack"(%321) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %323 = "tf.ConcatV2"(%11, %322, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %324 = "tf.Pack"(%317) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %325 = "tf.Pack"(%12, %192) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %326 = "tf.ExpandDims"(%325, %9) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %327 = "tf.Tile"(%326, %320) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %328 = "tf.Reshape"(%327, %323) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %329 = "tf.Shape"(%328) {device = ""} : (tensor) -> tensor<1xi64> - %330 = "tf.StridedSlice"(%329, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %331 = "tf.Sub"(%330, %317) {device = ""} : (tensor, tensor) -> tensor - %332 = "tf.Pack"(%331) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %333 = "tf.StridedSlice"(%328, %13, %332, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %334 = "tf.StridedSlice"(%328, %324, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %335:2 = "tf.RaggedRange"(%333, %334, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %336 = "tf.GatherV2"(%199, %335#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %337 = "tf.Cast"(%336) {Truncate = false, device = ""} : (tensor) -> tensor - %338 = "tf.BroadcastTo"(%337, %310) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %339 = "tf.Max"(%338, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %340 = "tf.Maximum"(%16, %339) {device = ""} : (tensor, tensor) -> tensor - %341 = "tf.Range"(%16, %340, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %342 = "tf.Pack"(%9, %340) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %343 = "tf.Tile"(%307, %342) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %344 = "tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> - %345 = "tf.StridedSlice"(%344, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %346 = "tf.Prod"(%345, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %347 = "tf.Pack"(%346) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %348 = 
"tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> - %349 = "tf.StridedSlice"(%348, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %350 = "tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> - %351 = "tf.StridedSlice"(%350, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %352 = "tf.ConcatV2"(%349, %347, %351, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %353 = "tf.Reshape"(%343, %352) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %354 = "tf.ExpandDims"(%338, %2) {device = ""} : (tensor, tensor) -> tensor - %355 = "tf.Less"(%341, %354) {device = ""} : (tensor, tensor) -> tensor - %356 = "tf.Reshape"(%355, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %357 = "tf.Where"(%356) {device = ""} : (tensor) -> tensor - %358 = "tf.Squeeze"(%357) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %359 = "tf.GatherV2"(%353, %358, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %360 = "tf.Cast"(%336) {Truncate = false, device = ""} : (tensor) -> tensor - %361 = "tf.BroadcastTo"(%360, %315) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %362 = "tf.Max"(%361, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %363 = "tf.Maximum"(%16, %362) {device = ""} : (tensor, tensor) -> tensor - %364 = "tf.Range"(%16, %363, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %365 = "tf.Pack"(%9, %363) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %366 = "tf.Tile"(%312, %365) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %367 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> - %368 = "tf.StridedSlice"(%367, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %369 = "tf.Prod"(%368, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %370 = "tf.Pack"(%369) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %371 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> - %372 = "tf.StridedSlice"(%371, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %373 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> - %374 = "tf.StridedSlice"(%373, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %375 = "tf.ConcatV2"(%372, %370, %374, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %376 = "tf.Reshape"(%366, %375) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %377 = "tf.ExpandDims"(%361, %2) {device = ""} : (tensor, tensor) -> tensor - %378 = "tf.Less"(%364, %377) {device = ""} : (tensor, tensor) -> tensor - %379 = "tf.Reshape"(%378, %6) {device = ""} : 
(tensor, tensor<1xi32>) -> tensor - %380 = "tf.Where"(%379) {device = ""} : (tensor) -> tensor - %381 = "tf.Squeeze"(%380) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %382 = "tf.GatherV2"(%376, %381, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %383:2 = "tf.RaggedRange"(%359, %382, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %384 = "tf.If"(%261, %261, %257, %67) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_106180, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_106170} : (tensor, tensor, tensor, tensor) -> tensor - %385 = "tf.Identity"(%384) {device = ""} : (tensor) -> tensor - %386 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %387 = "tf.Equal"(%386, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %388 = "tf.Select"(%387, %257, %386) {device = ""} : (tensor, tensor, tensor) -> tensor - %389 = "tf.Pack"(%388) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %390 = "tf.StridedSlice"(%62, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %391 = "tf.StridedSlice"(%62, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %392 = "tf.ConcatV2"(%390, %389, %391, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %393 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %394 = "tf.Equal"(%393, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %395 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %396 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %397 = "tf.Equal"(%396, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %398 = "tf.If"(%397, %397, %396, %336) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_106670, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_106660} : (tensor, tensor, tensor, 
tensor) -> tensor - %399 = "tf.Identity"(%398) {device = ""} : (tensor) -> tensor - %400 = "tf.If"(%394, %394, %336, %395) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_107030, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_107020} : (tensor, tensor, tensor, tensor) -> tensor - %401 = "tf.If"(%236, %236, %15, %232) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_111870, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_111860} : (tensor, tensor, tensor, tensor) -> tensor - %402 = "tf.Identity"(%401) {device = ""} : (tensor) -> tensor - %403 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %404 = "tf.Select"(%403, %15, %232) {device = ""} : (tensor, tensor, tensor) -> tensor - %405 = "tf.Equal"(%404, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %406 = "tf.LogicalOr"(%405, %1) {device = ""} : (tensor, tensor) -> tensor - %407 = "tf.Equal"(%404, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %408 = "tf.LogicalOr"(%406, %407) {device = ""} : (tensor, tensor) -> tensor - %409 = "tf.Select"(%243, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %410 = "tf.Pack"(%409, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %411 = "tf.StridedSlice"(%410, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %412 = "tf.Cast"(%411) {Truncate = false, device = ""} : (tensor) -> tensor - %413 = "tf.Reshape"(%412, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %414 = "tf.Pack"(%9, %413) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %415 = "tf.Tile"(%244, %414) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %416 = "tf.Mul"(%413, %247) {device = ""} : (tensor, tensor) -> tensor - %417 = "tf.Pack"(%416) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %418 = "tf.ConcatV2"(%246, %417, %248, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %419 = "tf.Reshape"(%415, %418) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %420 = "tf.Shape"(%419) {device = ""} : (tensor) -> tensor<1xi64> - %421 = "tf.StridedSlice"(%420, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %422 = "tf.Pack"(%411) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %423 = "tf.StridedSlice"(%419, %422, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %424 = "tf.Sub"(%421, %411) {device = ""} : (tensor, tensor) -> tensor - %425 = "tf.Pack"(%424) {axis = 0 : i64, device = ""} : (tensor) -> 
tensor<1xi64> - %426 = "tf.StridedSlice"(%419, %13, %425, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %427:2 = "tf.RaggedRange"(%426, %423, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %428 = "tf.GatherV2"(%250, %427#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %429 = "tf.StridedSlice"(%410, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %430 = "tf.StridedSlice"(%410, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %431 = "tf.StridedSlice"(%410, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %432 = "tf.ConcatV2"(%430, %431, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %433 = "tf.StridedSlice"(%410, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %434 = "tf.Mul"(%253, %433) {device = ""} : (tensor, tensor) -> tensor - %435 = "tf.Tile"(%434, %429) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %436 = "tf.Cumsum"(%435, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %437 = "tf.ConcatV2"(%13, %436, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %438 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi64> - %439 = "tf.StridedSlice"(%438, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %440 = "tf.Sub"(%439, %15) {device = ""} : (tensor, tensor) -> tensor - %441 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %442 = "tf.LogicalOr"(%117, %441) {device = ""} : (tensor, tensor) -> tensor - %443 = "tf.Equal"(%440, %116) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %444 = "tf.LogicalOr"(%442, %443) {device = ""} : (tensor, tensor) -> tensor - %445 = "tf.StridedSlice"(%437, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %446 = "tf.StridedSlice"(%437, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %447 = "tf.Sub"(%445, %446) {device = ""} : (tensor, tensor) -> tensor - %448 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi64> - %449 = "tf.StridedSlice"(%448, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, 
new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %450 = "tf.Sub"(%449, %15) {device = ""} : (tensor, tensor) -> tensor - %451 = "tf.Equal"(%450, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %452 = "tf.ExpandDims"(%437, %9) {device = ""} : (tensor, tensor) -> tensor - %453 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi32> - %454 = "tf.StridedSlice"(%453, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %455 = "tf.StridedSlice"(%453, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %456 = "tf.StridedSlice"(%453, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %457 = "tf.Select"(%1, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %458 = "tf.Pack"(%457, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %459 = "tf.StridedSlice"(%458, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %460 = "tf.Cast"(%459) {Truncate = false, device = ""} : (tensor) -> tensor - %461 = "tf.Reshape"(%460, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %462 = "tf.Pack"(%9, %461) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %463 = "tf.Tile"(%3, %462) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %464 = "tf.Mul"(%461, %10) {device = ""} : (tensor, tensor) -> tensor - %465 = "tf.Pack"(%464) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %466 = "tf.ConcatV2"(%11, %465, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %467 = "tf.Reshape"(%463, %466) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %468 = "tf.Shape"(%467) {device = ""} : (tensor) -> tensor<1xi64> - %469 = "tf.StridedSlice"(%468, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %470 = "tf.Pack"(%459) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %471 = "tf.StridedSlice"(%467, %470, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %472 = "tf.Sub"(%469, %459) {device = ""} : (tensor, tensor) -> tensor - %473 = "tf.Pack"(%472) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %474 = "tf.StridedSlice"(%467, %13, %473, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %475:2 = "tf.RaggedRange"(%474, %471, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, 
tensor, tensor) -> (tensor, tensor) - %476 = "tf.GatherV2"(%13, %475#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %477 = "tf.GatherV2"(%14, %476, %16) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %478 = "tf.StridedSlice"(%458, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %479 = "tf.StridedSlice"(%458, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %480 = "tf.StridedSlice"(%458, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %481 = "tf.ConcatV2"(%479, %480, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %482 = "tf.Tile"(%477, %481) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %483 = "tf.StridedSlice"(%458, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %484 = "tf.Mul"(%483, %14) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %485 = "tf.Tile"(%484, %478) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %486 = "tf.Cumsum"(%485, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %487 = "tf.ConcatV2"(%13, %486, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %488 = "tf.StridedSlice"(%487, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %489 = "tf.ExpandDims"(%488, %9) {device = ""} : (tensor, tensor) -> tensor - %490 = "tf.Shape"(%488) {device = ""} : (tensor) -> tensor<1xi32> - %491 = "tf.StridedSlice"(%490, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %492 = "tf.Pack"(%491) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %493 = "tf.StridedSlice"(%487, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %494 = "tf.ExpandDims"(%493, %9) {device = ""} : (tensor, tensor) -> tensor - %495 = "tf.Shape"(%493) {device = ""} : (tensor) -> tensor<1xi32> - %496 = "tf.StridedSlice"(%495, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %497 = "tf.Pack"(%496) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %498 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %499 = "tf.Select"(%498, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %500 
= "tf.Cast"(%499) {Truncate = false, device = ""} : (tensor) -> tensor - %501 = "tf.Reshape"(%500, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %502 = "tf.Pack"(%9, %501) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %503 = "tf.Mul"(%501, %10) {device = ""} : (tensor, tensor) -> tensor - %504 = "tf.Pack"(%503) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %505 = "tf.ConcatV2"(%11, %504, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %506 = "tf.Pack"(%499) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %507 = "tf.Pack"(%12, %232) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %508 = "tf.ExpandDims"(%507, %9) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %509 = "tf.Tile"(%508, %502) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %510 = "tf.Reshape"(%509, %505) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %511 = "tf.Shape"(%510) {device = ""} : (tensor) -> tensor<1xi64> - %512 = "tf.StridedSlice"(%511, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %513 = "tf.Sub"(%512, %499) {device = ""} : (tensor, tensor) -> tensor - %514 = "tf.Pack"(%513) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %515 = "tf.StridedSlice"(%510, %13, %514, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %516 = "tf.StridedSlice"(%510, %506, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %517:2 = "tf.RaggedRange"(%515, %516, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %518 = "tf.GatherV2"(%239, %517#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %519 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor - %520 = "tf.BroadcastTo"(%519, %492) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %521 = "tf.Max"(%520, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %522 = "tf.Maximum"(%16, %521) {device = ""} : (tensor, tensor) -> tensor - %523 = "tf.Range"(%16, %522, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %524 = "tf.Pack"(%9, %522) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %525 = "tf.Tile"(%489, %524) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %526 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> - %527 = "tf.StridedSlice"(%526, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %528 = "tf.Prod"(%527, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %529 = "tf.Pack"(%528) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %530 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> - %531 = "tf.StridedSlice"(%530, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 
0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %532 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> - %533 = "tf.StridedSlice"(%532, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %534 = "tf.ConcatV2"(%531, %529, %533, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %535 = "tf.Reshape"(%525, %534) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %536 = "tf.ExpandDims"(%520, %2) {device = ""} : (tensor, tensor) -> tensor - %537 = "tf.Less"(%523, %536) {device = ""} : (tensor, tensor) -> tensor - %538 = "tf.Reshape"(%537, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %539 = "tf.Where"(%538) {device = ""} : (tensor) -> tensor - %540 = "tf.Squeeze"(%539) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %541 = "tf.GatherV2"(%535, %540, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %542 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor - %543 = "tf.BroadcastTo"(%542, %497) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %544 = "tf.Max"(%543, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %545 = "tf.Maximum"(%16, %544) {device = ""} : (tensor, tensor) -> tensor - %546 = "tf.Range"(%16, %545, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %547 = "tf.Pack"(%9, %545) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %548 = "tf.Tile"(%494, %547) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %549 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> - %550 = "tf.StridedSlice"(%549, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %551 = "tf.Prod"(%550, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %552 = "tf.Pack"(%551) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %553 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> - %554 = "tf.StridedSlice"(%553, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %555 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> - %556 = "tf.StridedSlice"(%555, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %557 = "tf.ConcatV2"(%554, %552, %556, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %558 = "tf.Reshape"(%548, %557) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %559 = "tf.ExpandDims"(%543, %2) {device = ""} : (tensor, tensor) -> tensor - %560 = "tf.Less"(%546, %559) {device = ""} : (tensor, tensor) -> tensor - %561 = "tf.Reshape"(%560, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %562 = "tf.Where"(%561) {device = ""} : (tensor) -> tensor - %563 = "tf.Squeeze"(%562) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %564 = 
"tf.GatherV2"(%558, %563, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %565:2 = "tf.RaggedRange"(%541, %564, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %566 = "tf.GatherV2"(%482, %565#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %567 = "tf.If"(%408, %408, %404, %15) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_112940, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_112930} : (tensor, tensor, tensor, tensor) -> tensor - %568 = "tf.Identity"(%567) {device = ""} : (tensor) -> tensor - %569 = "tf.Select"(%1, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %570 = "tf.Pack"(%569) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %571 = "tf.ConcatV2"(%0, %570, %14, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %572 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %573 = "tf.Equal"(%572, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %574 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %575 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %576 = "tf.Equal"(%575, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %577 = "tf.If"(%576, %576, %575, %518) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_113430, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_113420} : (tensor, tensor, tensor, tensor) -> tensor - %578 = "tf.Identity"(%577) {device = ""} : (tensor) -> tensor - %579 = "tf.If"(%573, %573, %518, %574) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_113790, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_113780} : (tensor, tensor, tensor, tensor) -> tensor - %580 = "tf.Identity"(%579) {device = ""} : (tensor) -> tensor - %581 = "tf.If"(%444, %444, %116, %440) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_118470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_118460} : (tensor, tensor, 
tensor, tensor) -> tensor - %582 = "tf.Identity"(%581) {device = ""} : (tensor) -> tensor - %583 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %584 = "tf.Select"(%583, %116, %440) {device = ""} : (tensor, tensor, tensor) -> tensor - %585 = "tf.Equal"(%584, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %586 = "tf.LogicalOr"(%585, %119) {device = ""} : (tensor, tensor) -> tensor - %587 = "tf.Equal"(%118, %584) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %588 = "tf.LogicalOr"(%586, %587) {device = ""} : (tensor, tensor) -> tensor - %589 = "tf.Select"(%451, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %590 = "tf.Pack"(%589, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %591 = "tf.StridedSlice"(%590, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %592 = "tf.Cast"(%591) {Truncate = false, device = ""} : (tensor) -> tensor - %593 = "tf.Reshape"(%592, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %594 = "tf.Pack"(%9, %593) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %595 = "tf.Tile"(%452, %594) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %596 = "tf.Mul"(%593, %455) {device = ""} : (tensor, tensor) -> tensor - %597 = "tf.Pack"(%596) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %598 = "tf.ConcatV2"(%454, %597, %456, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %599 = "tf.Reshape"(%595, %598) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %600 = "tf.Shape"(%599) {device = ""} : (tensor) -> tensor<1xi64> - %601 = "tf.StridedSlice"(%600, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %602 = "tf.Pack"(%591) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %603 = "tf.StridedSlice"(%599, %602, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %604 = "tf.Sub"(%601, %591) {device = ""} : (tensor, tensor) -> tensor - %605 = "tf.Pack"(%604) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %606 = "tf.StridedSlice"(%599, %13, %605, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %607:2 = "tf.RaggedRange"(%606, %603, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %608 = "tf.Select"(%124, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %609 = "tf.Pack"(%608, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %610 = "tf.StridedSlice"(%609, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %611 = "tf.Cast"(%610) {Truncate = false, device = ""} : (tensor) -> tensor - %612 = "tf.Reshape"(%611, %11) 
{device = ""} : (tensor, tensor<0xi32>) -> tensor - %613 = "tf.Pack"(%9, %612) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %614 = "tf.Tile"(%106, %613) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %615 = "tf.Mul"(%612, %109) {device = ""} : (tensor, tensor) -> tensor - %616 = "tf.Pack"(%615) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %617 = "tf.ConcatV2"(%108, %616, %110, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %618 = "tf.Reshape"(%614, %617) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %619 = "tf.Shape"(%618) {device = ""} : (tensor) -> tensor<1xi64> - %620 = "tf.StridedSlice"(%619, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %621 = "tf.Pack"(%610) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %622 = "tf.StridedSlice"(%618, %621, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %623 = "tf.Sub"(%620, %610) {device = ""} : (tensor, tensor) -> tensor - %624 = "tf.Pack"(%623) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %625 = "tf.StridedSlice"(%618, %13, %624, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %626:2 = "tf.RaggedRange"(%625, %622, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %627 = "tf.StridedSlice"(%609, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %628 = "tf.StridedSlice"(%609, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %629 = "tf.Mul"(%113, %628) {device = ""} : (tensor, tensor) -> tensor - %630 = "tf.Tile"(%629, %627) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %631 = "tf.Cumsum"(%630, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %632 = "tf.ConcatV2"(%13, %631, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %633 = "tf.StridedSlice"(%632, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %634 = "tf.ExpandDims"(%633, %9) {device = ""} : (tensor, tensor) -> tensor - %635 = "tf.Shape"(%633) {device = ""} : (tensor) -> tensor<1xi32> - %636 = "tf.StridedSlice"(%635, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %637 = "tf.Pack"(%636) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %638 = "tf.StridedSlice"(%632, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %639 = "tf.ExpandDims"(%638, %9) {device = ""} : (tensor, tensor) -> tensor - %640 = "tf.Shape"(%638) {device = ""} : (tensor) -> tensor<1xi32> - %641 = "tf.StridedSlice"(%640, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %642 = "tf.Pack"(%641) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %643 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %644 = "tf.Select"(%643, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %645 = "tf.Cast"(%644) {Truncate = false, device = ""} : (tensor) -> tensor - %646 = "tf.Reshape"(%645, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %647 = "tf.Pack"(%9, %646) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %648 = "tf.Mul"(%646, %10) {device = ""} : (tensor, tensor) -> tensor - %649 = "tf.Pack"(%648) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %650 = "tf.ConcatV2"(%11, %649, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %651 = "tf.Pack"(%644) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %652 = "tf.Pack"(%12, %440) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %653 = "tf.ExpandDims"(%652, %9) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %654 = "tf.Tile"(%653, %647) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %655 = "tf.Reshape"(%654, %650) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %656 = "tf.Shape"(%655) {device = ""} : (tensor) -> tensor<1xi64> - %657 = "tf.StridedSlice"(%656, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %658 = "tf.Sub"(%657, %644) {device = ""} : (tensor, tensor) -> tensor - %659 = "tf.Pack"(%658) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %660 = "tf.StridedSlice"(%655, %13, %659, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %661 = "tf.StridedSlice"(%655, %651, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %662:2 = "tf.RaggedRange"(%660, %661, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %663 = "tf.GatherV2"(%447, %662#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %664 = "tf.Cast"(%663) {Truncate = false, device = ""} : (tensor) -> tensor - %665 = "tf.BroadcastTo"(%664, %637) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %666 = "tf.Max"(%665, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %667 = "tf.Maximum"(%16, %666) {device = ""} : (tensor, tensor) -> tensor - %668 = "tf.Range"(%16, %667, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %669 = "tf.Pack"(%9, %667) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %670 = "tf.Tile"(%634, %669) 
{device = ""} : (tensor, tensor<2xi32>) -> tensor - %671 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> - %672 = "tf.StridedSlice"(%671, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %673 = "tf.Prod"(%672, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %674 = "tf.Pack"(%673) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %675 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> - %676 = "tf.StridedSlice"(%675, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %677 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> - %678 = "tf.StridedSlice"(%677, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %679 = "tf.ConcatV2"(%676, %674, %678, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %680 = "tf.Reshape"(%670, %679) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %681 = "tf.ExpandDims"(%665, %2) {device = ""} : (tensor, tensor) -> tensor - %682 = "tf.Less"(%668, %681) {device = ""} : (tensor, tensor) -> tensor - %683 = "tf.Reshape"(%682, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %684 = "tf.Where"(%683) {device = ""} : (tensor) -> tensor - %685 = "tf.Squeeze"(%684) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %686 = "tf.GatherV2"(%680, %685, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %687 = "tf.Cast"(%663) {Truncate = false, device = ""} : (tensor) -> tensor - %688 = "tf.BroadcastTo"(%687, %642) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %689 = "tf.Max"(%688, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %690 = "tf.Maximum"(%16, %689) {device = ""} : (tensor, tensor) -> tensor - %691 = "tf.Range"(%16, %690, %9) {device = ""} : (tensor, tensor, tensor) -> tensor - %692 = "tf.Pack"(%9, %690) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %693 = "tf.Tile"(%639, %692) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %694 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> - %695 = "tf.StridedSlice"(%694, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %696 = "tf.Prod"(%695, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %697 = "tf.Pack"(%696) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %698 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> - %699 = "tf.StridedSlice"(%698, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %700 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> - %701 = "tf.StridedSlice"(%700, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 
: i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %702 = "tf.ConcatV2"(%699, %697, %701, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %703 = "tf.Reshape"(%693, %702) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %704 = "tf.ExpandDims"(%688, %2) {device = ""} : (tensor, tensor) -> tensor - %705 = "tf.Less"(%691, %704) {device = ""} : (tensor, tensor) -> tensor - %706 = "tf.Reshape"(%705, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %707 = "tf.Where"(%706) {device = ""} : (tensor) -> tensor - %708 = "tf.Squeeze"(%707) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %709 = "tf.GatherV2"(%703, %708, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %710:2 = "tf.RaggedRange"(%686, %709, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %711 = "tf.If"(%588, %588, %584, %120) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_119540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_119530} : (tensor, tensor, tensor, tensor) -> tensor - %712 = "tf.Identity"(%711) {device = ""} : (tensor) -> tensor - %713 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %714 = "tf.Equal"(%713, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %715 = "tf.Select"(%714, %584, %713) {device = ""} : (tensor, tensor, tensor) -> tensor - %716 = "tf.Pack"(%715) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %717 = "tf.StridedSlice"(%115, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %718 = "tf.StridedSlice"(%115, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %719 = "tf.ConcatV2"(%717, %716, %718, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %720 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %721 = "tf.Equal"(%720, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %722 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %723 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor - %724 = "tf.Equal"(%723, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %725 = "tf.If"(%724, %724, %723, %663) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_120030, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_120020} : (tensor, tensor, tensor, tensor) -> tensor - %726 = "tf.Identity"(%725) {device = ""} : (tensor) -> tensor - %727 = "tf.If"(%721, %721, %663, %722) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_120390, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_120380} : (tensor, tensor, tensor, tensor) -> tensor - %728 = "tf.Identity"(%168) {device = ""} : (tensor) -> tensor - %729 = "tf.Identity"(%727) {device = ""} : (tensor) -> tensor - %730 = "tf.Identity"(%400) {device = ""} : (tensor) -> tensor - %731 = "tf.Shape"(%125#2) {device = ""} : (tensor) -> tensor<1xi32> - %732 = "tf.StridedSlice"(%731, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %733 = "tf.Cast"(%732) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %734 = "tf.Identity"(%733) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %735 = "tf.Shape"(%125#3) {device = ""} : (tensor) -> tensor<1xi32> - %736 = "tf.StridedSlice"(%735, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %737 = "tf.Cast"(%736) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %738 = "tf.Identity"(%737) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %739 = "tf.GatherV2"(%125#3, %428, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %740 = "tf.Tile"(%739, %432) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %741 = "tf.Sub"(%740, %566) {device = ""} : (tensor, tensor) -> tensor - %742 = "tf.Shape"(%741) {device = ""} : (tensor) -> tensor<1xi32> - %743 = "tf.StridedSlice"(%742, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %744 = "tf.Cast"(%743) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %745 = "tf.Identity"(%744) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %746 = "tf.UnicodeEncode"(%125#0, %146) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor - %747 = "tf.Identity"(%746) {device = ""} : (tensor) -> tensor - %748 = "tf.StridedSlice"(%19, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %749 = "tf.AddV2"(%748, %15) {device = ""} : (tensor, tensor) -> tensor - %750 = "tf.Range"(%12, %749, %15) {device = ""} : (tensor, tensor, tensor) -> tensor - %751 = "tf.Mul"(%750, %15) {device = ""} : (tensor, tensor) -> tensor - %752 = "tf.Identity"(%751) {device = ""} : (tensor) -> tensor - return %747, %752, %728 : tensor, tensor, tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_99640(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedFromTensor/strided_slice_4:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedNRows/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_99630(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_100400(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_100390(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = 
"tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_100760(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_100750(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101090(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor 
attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_101830(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_101820(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_102190(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold 
element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_102180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_102540(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_102530(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_102900(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x 
(WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_102890(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103610(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} 
: () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_103970(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_103960(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_104310(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_104300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_105110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_105100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_106180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_106170(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_106670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = 
dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_106660(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_107030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_107020(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_111870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_111860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_112940(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = 
dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_112930(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_113430(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_113420(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_113790(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_113780(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_118470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_118460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_119540(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_119530(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_120030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_120020(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_120390(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_120380(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - - - // CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {tf._input_shapes = [#tf.shape], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - // CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor) -> (tensor, tensor, tensor) - // CHECK: return %0#0, %0#1, %0#2 : tensor, tensor, tensor - - func @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> - %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> - %2 = "tf.Const"() {value = dense : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %4 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> - %5 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> - %6 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %7 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %8 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor - %9 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> - %10 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %11 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> - %12 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> - %13 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %14 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %15 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %16 = "tf.Const"() 
{value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %17 = "tf.If"(%2, %2, %13, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3220, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3210} : (tensor, tensor, tensor, tensor) -> tensor - %18 = "tf.Identity"(%17) {device = ""} : (tensor) -> tensor - %19 = "tf.Pack"(%arg0) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1x!tf.string> - %20 = "tf.StringLength"(%19) {device = "", unit = "BYTE"} : (tensor<1x!tf.string>) -> tensor<1xi32> - %21 = "tf.ExpandDims"(%20, %7) {device = ""} : (tensor<1xi32>, tensor) -> tensor<1x1xi32> - %22 = "tf.Cast"(%21) {Truncate = false, device = ""} : (tensor<1x1xi32>) -> tensor<1x1xi64> - %23 = "tf.Reshape"(%22, %12) {device = ""} : (tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> - %24 = "tf.Reshape"(%19, %5) {device = ""} : (tensor<1x!tf.string>, tensor<1xi32>) -> tensor<1x!tf.string> - %25:3 = "tf.UnicodeDecodeWithOffsets"(%24) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor<1x!tf.string>) -> (tensor<2xi64>, tensor, tensor) - %26 = "tf.StridedSlice"(%25#0, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %27 = "tf.AddV2"(%26, %13) {device = ""} : (tensor<1xi64>, tensor) -> tensor<1xi64> - %28 = "tf.StridedSlice"(%25#0, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %29 = "tf.Minimum"(%27, %28) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> - %30:2 = "tf.RaggedRange"(%29, %28, %13) {T = i64, Tsplits = i64, device = ""} : (tensor<1xi64>, tensor<1xi64>, tensor) -> (tensor<2xi64>, tensor) - %31 = "tf.StridedSlice"(%30#0, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %32 = "tf.AddV2"(%31, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %33 = "tf.ConcatV2"(%30#0, %32, %14) {device = ""} : (tensor<2xi64>, tensor<1xi64>, tensor) -> tensor<3xi64> - %34 = "tf.GatherV2"(%25#2, %30#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %35 = "tf.ConcatV2"(%34, %23, %14) {device = ""} : (tensor, tensor<1xi64>, tensor) -> tensor - %36:2 = "tf.RaggedGather"(%33, %35, %0) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor<3xi64>, tensor, tensor<2xi64>) -> (tensor, tensor) - %37:5 = "tf.WhitespaceTokenizeWithOffsets"(%25#1, %25#0) {Tsplits = i64, device = ""} : (tensor, tensor<2xi64>) -> (tensor, tensor, tensor, tensor, tensor) - %38 = "tf.StridedSlice"(%37#1, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %39 = "tf.Equal"(%38, 
%10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %40 = "tf.All"(%39, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %41 = "tf.If"(%40, %40, %38, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3980, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3970} : (tensor, tensor, tensor, tensor) -> tensor - %42 = "tf.Identity"(%41) {device = ""} : (tensor) -> tensor - %43 = "tf.StridedSlice"(%37#1, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %44 = "tf.StridedSlice"(%37#1, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %45 = "tf.Sub"(%43, %44) {device = ""} : (tensor, tensor) -> tensor - %46 = "tf.LessEqual"(%10, %45) {device = ""} : (tensor, tensor) -> tensor - %47 = "tf.All"(%46, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %48 = "tf.If"(%47, %47, %45) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4340, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4330} : (tensor, tensor, tensor) -> tensor - %49 = "tf.Identity"(%48) {device = ""} : (tensor) -> tensor - %50 = "tf.Identity"(%37#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %51 = "tf.StridedSlice"(%50, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %52 = "tf.Shape"(%37#0) {device = ""} : (tensor) -> tensor<1xi64> - %53 = "tf.StridedSlice"(%52, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %54 = "tf.Equal"(%51, %53) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %55 = "tf.All"(%54, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %56 = "tf.If"(%55, %55, %51, %53) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4680, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = 
@WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4670} : (tensor, tensor, tensor, tensor) -> tensor - %57 = "tf.Identity"(%56) {device = ""} : (tensor) -> tensor - %58 = "tf.Identity"(%50) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %59 = "tf.Shape"(%58) {device = ""} : (tensor) -> tensor<1xi64> - %60 = "tf.StridedSlice"(%59, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %61 = "tf.Sub"(%60, %13) {device = ""} : (tensor, tensor) -> tensor - %62 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %63 = "tf.Equal"(%62, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %64 = "tf.All"(%63, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %65 = "tf.If"(%64, %64, %62, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5050, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5040} : (tensor, tensor, tensor, tensor) -> tensor - %66 = "tf.Identity"(%65) {device = ""} : (tensor) -> tensor - %67 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %68 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %69 = "tf.Sub"(%67, %68) {device = ""} : (tensor, tensor) -> tensor - %70 = "tf.LessEqual"(%10, %69) {device = ""} : (tensor, tensor) -> tensor - %71 = "tf.All"(%70, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %72 = "tf.If"(%71, %71, %69) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5410, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5400} : (tensor, tensor, tensor) -> tensor - %73 = "tf.Identity"(%72) {device = ""} : (tensor) -> tensor - %74 = "tf.Identity"(%37#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %75 = "tf.StridedSlice"(%74, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %76 = "tf.Equal"(%75, %61) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %77 = "tf.All"(%76, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %78 = "tf.If"(%77, %77, %75, %61) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5770, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5760} : (tensor, tensor, tensor, tensor) -> tensor - %79 = "tf.Identity"(%78) {device = ""} : (tensor) -> tensor - %80 = "tf.Identity"(%74) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %81 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %82 = "tf.Equal"(%81, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %83 = "tf.All"(%82, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %84 = "tf.If"(%83, %83, %81, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6120, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6110} : (tensor, tensor, tensor, tensor) -> tensor - %85 = "tf.Identity"(%84) {device = ""} : (tensor) -> tensor - %86 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %87 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %88 = "tf.Sub"(%86, %87) {device = ""} : (tensor, tensor) -> tensor - %89 = "tf.LessEqual"(%10, %88) {device = ""} : (tensor, tensor) -> tensor - %90 = "tf.All"(%89, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %91 = "tf.If"(%90, %90, %88) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6480, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6470} : (tensor, tensor, tensor) -> tensor - %92 = "tf.Identity"(%91) {device = ""} : (tensor) -> tensor - %93 = "tf.Identity"(%37#4) {_class = 
["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %94 = "tf.StridedSlice"(%93, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %95 = "tf.Shape"(%37#2) {device = ""} : (tensor) -> tensor<1xi64> - %96 = "tf.StridedSlice"(%95, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %97 = "tf.Equal"(%94, %96) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %98 = "tf.All"(%97, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %99 = "tf.If"(%98, %98, %94, %96) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6820, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6810} : (tensor, tensor, tensor, tensor) -> tensor - %100 = "tf.Identity"(%99) {device = ""} : (tensor) -> tensor - %101 = "tf.Identity"(%93) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %102 = "tf.Shape"(%101) {device = ""} : (tensor) -> tensor<1xi64> - %103 = "tf.StridedSlice"(%102, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %104 = "tf.Sub"(%103, %13) {device = ""} : (tensor, tensor) -> tensor - %105 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %106 = "tf.LogicalOr"(%105, %2) {device = ""} : (tensor, tensor) -> tensor - %107 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %108 = "tf.LogicalOr"(%106, %107) {device = ""} : (tensor, tensor) -> tensor - %109 = "tf.StridedSlice"(%101, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %110 = "tf.StridedSlice"(%101, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %111 = "tf.Sub"(%109, %110) {device = ""} : (tensor, tensor) -> tensor - %112 = "tf.Shape"(%101) {device = ""} : (tensor) -> tensor<1xi64> - %113 = "tf.StridedSlice"(%112, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %114 = "tf.Sub"(%113, %13) {device = ""} : (tensor, tensor) -> tensor - %115 = "tf.Equal"(%114, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %116 = "tf.ExpandDims"(%101, %7) {device = ""} : (tensor, tensor) -> tensor - %117 = "tf.Shape"(%101) {device = ""} : 
(tensor) -> tensor<1xi32> - %118 = "tf.StridedSlice"(%117, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %119 = "tf.StridedSlice"(%117, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %120 = "tf.StridedSlice"(%117, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %121 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %122 = "tf.Equal"(%121, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %123 = "tf.All"(%122, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %124 = "tf.If"(%123, %123, %121, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7190, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7180} : (tensor, tensor, tensor, tensor) -> tensor - %125 = "tf.Identity"(%124) {device = ""} : (tensor) -> tensor - %126 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %127 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %128 = "tf.Sub"(%126, %127) {device = ""} : (tensor, tensor) -> tensor - %129 = "tf.LessEqual"(%10, %128) {device = ""} : (tensor, tensor) -> tensor - %130 = "tf.All"(%129, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %131 = "tf.If"(%130, %130, %128) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7550, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7540} : (tensor, tensor, tensor) -> tensor - %132 = "tf.Identity"(%131) {device = ""} : (tensor) -> tensor - %133 = "tf.Identity"(%37#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %134 = "tf.StridedSlice"(%133, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, 
end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %135 = "tf.Shape"(%37#3) {device = ""} : (tensor) -> tensor<1xi64> - %136 = "tf.StridedSlice"(%135, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %137 = "tf.Equal"(%134, %136) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %138 = "tf.All"(%137, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor - %139 = "tf.If"(%138, %138, %134, %136) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7890, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7880} : (tensor, tensor, tensor, tensor) -> tensor - %140 = "tf.Identity"(%139) {device = ""} : (tensor) -> tensor - %141 = "tf.Identity"(%133) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor - %142 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi64> - %143 = "tf.StridedSlice"(%142, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %144 = "tf.Sub"(%143, %13) {device = ""} : (tensor, tensor) -> tensor - %145 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %146 = "tf.LogicalOr"(%145, %2) {device = ""} : (tensor, tensor) -> tensor - %147 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %148 = "tf.LogicalOr"(%146, %147) {device = ""} : (tensor, tensor) -> tensor - %149 = "tf.StridedSlice"(%141, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %150 = "tf.StridedSlice"(%141, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %151 = "tf.Sub"(%149, %150) {device = ""} : (tensor, tensor) -> tensor - %152 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi64> - %153 = "tf.StridedSlice"(%152, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %154 = "tf.Sub"(%153, %13) {device = ""} : (tensor, tensor) -> tensor - %155 = "tf.Equal"(%154, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %156 = "tf.ExpandDims"(%141, %7) {device = ""} : (tensor, tensor) -> tensor - %157 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi32> - %158 = "tf.StridedSlice"(%157, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : 
i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %159 = "tf.StridedSlice"(%157, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %160 = "tf.StridedSlice"(%157, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %161 = "tf.StridedSlice"(%141, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %162 = "tf.Range"(%10, %161, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %163 = "tf.StridedSlice"(%141, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %164 = "tf.StridedSlice"(%141, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %165 = "tf.Sub"(%163, %164) {device = ""} : (tensor, tensor) -> tensor - %166 = "tf.If"(%108, %108, %13, %104) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8690, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8680} : (tensor, tensor, tensor, tensor) -> tensor - %167 = "tf.Identity"(%166) {device = ""} : (tensor) -> tensor - %168 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %169 = "tf.Select"(%168, %13, %104) {device = ""} : (tensor, tensor, tensor) -> tensor - %170 = "tf.Equal"(%169, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %171 = "tf.LogicalOr"(%170, %2) {device = ""} : (tensor, tensor) -> tensor - %172 = "tf.Equal"(%169, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %173 = "tf.LogicalOr"(%171, %172) {device = ""} : (tensor, tensor) -> tensor - %174 = "tf.Select"(%115, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %175 = "tf.Pack"(%174, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %176 = "tf.StridedSlice"(%175, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %177 = "tf.Cast"(%176) {Truncate = false, device = ""} : (tensor) -> tensor - %178 = "tf.Reshape"(%177, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %179 = "tf.Pack"(%7, %178) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %180 = "tf.Tile"(%116, %179) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %181 = "tf.Mul"(%178, %119) {device = ""} : (tensor, tensor) -> tensor - %182 = "tf.Pack"(%181) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %183 = "tf.ConcatV2"(%118, %182, %120, %14) {device = ""} : (tensor<0xi32>, 
tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %184 = "tf.Reshape"(%180, %183) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %185 = "tf.Shape"(%184) {device = ""} : (tensor) -> tensor<1xi64> - %186 = "tf.StridedSlice"(%185, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %187 = "tf.Pack"(%176) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %188 = "tf.StridedSlice"(%184, %187, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %189 = "tf.Sub"(%186, %176) {device = ""} : (tensor, tensor) -> tensor - %190 = "tf.Pack"(%189) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %191 = "tf.StridedSlice"(%184, %11, %190, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %192:2 = "tf.RaggedRange"(%191, %188, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %193 = "tf.Select"(%2, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %194 = "tf.Pack"(%193, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %195 = "tf.StridedSlice"(%194, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %196 = "tf.Cast"(%195) {Truncate = false, device = ""} : (tensor) -> tensor - %197 = "tf.Reshape"(%196, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %198 = "tf.Pack"(%7, %197) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %199 = "tf.Tile"(%4, %198) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %200 = "tf.Mul"(%197, %8) {device = ""} : (tensor, tensor) -> tensor - %201 = "tf.Pack"(%200) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %202 = "tf.ConcatV2"(%9, %201, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %203 = "tf.Reshape"(%199, %202) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %204 = "tf.Shape"(%203) {device = ""} : (tensor) -> tensor<1xi64> - %205 = "tf.StridedSlice"(%204, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %206 = "tf.Pack"(%195) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %207 = "tf.StridedSlice"(%203, %206, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %208 = "tf.Sub"(%205, %195) {device = ""} : (tensor, tensor) -> tensor - %209 = "tf.Pack"(%208) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %210 = "tf.StridedSlice"(%203, %11, %209, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor 
- %211:2 = "tf.RaggedRange"(%210, %207, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %212 = "tf.StridedSlice"(%194, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %213 = "tf.StridedSlice"(%194, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %214 = "tf.Mul"(%213, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %215 = "tf.Tile"(%214, %212) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %216 = "tf.Cumsum"(%215, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %217 = "tf.ConcatV2"(%11, %216, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %218 = "tf.StridedSlice"(%217, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %219 = "tf.ExpandDims"(%218, %7) {device = ""} : (tensor, tensor) -> tensor - %220 = "tf.Shape"(%218) {device = ""} : (tensor) -> tensor<1xi32> - %221 = "tf.StridedSlice"(%220, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %222 = "tf.Pack"(%221) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %223 = "tf.StridedSlice"(%217, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %224 = "tf.ExpandDims"(%223, %7) {device = ""} : (tensor, tensor) -> tensor - %225 = "tf.Shape"(%223) {device = ""} : (tensor) -> tensor<1xi32> - %226 = "tf.StridedSlice"(%225, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %227 = "tf.Pack"(%226) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %228 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %229 = "tf.Select"(%228, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %230 = "tf.Cast"(%229) {Truncate = false, device = ""} : (tensor) -> tensor - %231 = "tf.Reshape"(%230, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %232 = "tf.Pack"(%7, %231) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %233 = "tf.Mul"(%231, %8) {device = ""} : (tensor, tensor) -> tensor - %234 = "tf.Pack"(%233) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %235 = "tf.ConcatV2"(%9, %234, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %236 = "tf.Pack"(%229) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %237 = "tf.Pack"(%10, %104) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %238 = "tf.ExpandDims"(%237, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %239 = "tf.Tile"(%238, %232) {device = ""} : 
(tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %240 = "tf.Reshape"(%239, %235) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %241 = "tf.Shape"(%240) {device = ""} : (tensor) -> tensor<1xi64> - %242 = "tf.StridedSlice"(%241, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %243 = "tf.Sub"(%242, %229) {device = ""} : (tensor, tensor) -> tensor - %244 = "tf.Pack"(%243) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %245 = "tf.StridedSlice"(%240, %11, %244, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %246 = "tf.StridedSlice"(%240, %236, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %247:2 = "tf.RaggedRange"(%245, %246, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %248 = "tf.GatherV2"(%111, %247#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %249 = "tf.Cast"(%248) {Truncate = false, device = ""} : (tensor) -> tensor - %250 = "tf.BroadcastTo"(%249, %222) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %251 = "tf.Max"(%250, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %252 = "tf.Maximum"(%14, %251) {device = ""} : (tensor, tensor) -> tensor - %253 = "tf.Range"(%14, %252, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %254 = "tf.Pack"(%7, %252) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %255 = "tf.Tile"(%219, %254) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %256 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> - %257 = "tf.StridedSlice"(%256, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %258 = "tf.Prod"(%257, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %259 = "tf.Pack"(%258) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %260 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> - %261 = "tf.StridedSlice"(%260, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %262 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> - %263 = "tf.StridedSlice"(%262, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %264 = "tf.ConcatV2"(%261, %259, %263, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %265 = "tf.Reshape"(%255, %264) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %266 = "tf.ExpandDims"(%250, %3) {device = ""} : (tensor, tensor) -> tensor - %267 = "tf.Less"(%253, %266) {device = ""} : (tensor, tensor) -> tensor - %268 = "tf.Reshape"(%267, %5) {device 
= ""} : (tensor, tensor<1xi32>) -> tensor - %269 = "tf.Where"(%268) {device = ""} : (tensor) -> tensor - %270 = "tf.Squeeze"(%269) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %271 = "tf.GatherV2"(%265, %270, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %272 = "tf.Cast"(%248) {Truncate = false, device = ""} : (tensor) -> tensor - %273 = "tf.BroadcastTo"(%272, %227) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %274 = "tf.Max"(%273, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %275 = "tf.Maximum"(%14, %274) {device = ""} : (tensor, tensor) -> tensor - %276 = "tf.Range"(%14, %275, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %277 = "tf.Pack"(%7, %275) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %278 = "tf.Tile"(%224, %277) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %279 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> - %280 = "tf.StridedSlice"(%279, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %281 = "tf.Prod"(%280, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %282 = "tf.Pack"(%281) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %283 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> - %284 = "tf.StridedSlice"(%283, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %285 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> - %286 = "tf.StridedSlice"(%285, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %287 = "tf.ConcatV2"(%284, %282, %286, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %288 = "tf.Reshape"(%278, %287) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %289 = "tf.ExpandDims"(%273, %3) {device = ""} : (tensor, tensor) -> tensor - %290 = "tf.Less"(%276, %289) {device = ""} : (tensor, tensor) -> tensor - %291 = "tf.Reshape"(%290, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %292 = "tf.Where"(%291) {device = ""} : (tensor) -> tensor - %293 = "tf.Squeeze"(%292) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %294 = "tf.GatherV2"(%288, %293, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %295:2 = "tf.RaggedRange"(%271, %294, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %296 = "tf.If"(%173, %173, %169, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9750} : (tensor, tensor, tensor, tensor) -> tensor - %297 = "tf.Identity"(%296) {device = ""} : (tensor) -> tensor - %298 = "tf.Select"(%2, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %299 = "tf.Pack"(%298) {axis = 0 : i64, device = ""} : (tensor) -> 
tensor<1xi64> - %300 = "tf.ConcatV2"(%1, %299, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %301 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %302 = "tf.Equal"(%301, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %303 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %304 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %305 = "tf.Equal"(%304, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %306 = "tf.If"(%305, %305, %304, %248) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10250, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10240} : (tensor, tensor, tensor, tensor) -> tensor - %307 = "tf.Identity"(%306) {device = ""} : (tensor) -> tensor - %308 = "tf.If"(%302, %302, %248, %303) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10610, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10600} : (tensor, tensor, tensor, tensor) -> tensor - %309 = "tf.If"(%148, %148, %13, %144) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_15310, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_15300} : (tensor, tensor, tensor, tensor) -> tensor - %310 = "tf.Identity"(%309) {device = ""} : (tensor) -> tensor - %311 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %312 = "tf.Select"(%311, %13, %144) {device = ""} : (tensor, tensor, tensor) -> tensor - %313 = "tf.Equal"(%312, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %314 = "tf.LogicalOr"(%313, %2) {device = ""} : (tensor, tensor) -> tensor - %315 = "tf.Equal"(%312, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %316 = "tf.LogicalOr"(%314, %315) {device = ""} : (tensor, tensor) -> tensor - %317 = "tf.Select"(%155, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %318 = "tf.Pack"(%317, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %319 = "tf.StridedSlice"(%318, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %320 = "tf.Cast"(%319) {Truncate = false, device = ""} : (tensor) -> 
tensor - %321 = "tf.Reshape"(%320, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %322 = "tf.Pack"(%7, %321) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %323 = "tf.Tile"(%156, %322) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %324 = "tf.Mul"(%321, %159) {device = ""} : (tensor, tensor) -> tensor - %325 = "tf.Pack"(%324) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %326 = "tf.ConcatV2"(%158, %325, %160, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %327 = "tf.Reshape"(%323, %326) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %328 = "tf.Shape"(%327) {device = ""} : (tensor) -> tensor<1xi64> - %329 = "tf.StridedSlice"(%328, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %330 = "tf.Pack"(%319) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %331 = "tf.StridedSlice"(%327, %330, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %332 = "tf.Sub"(%329, %319) {device = ""} : (tensor, tensor) -> tensor - %333 = "tf.Pack"(%332) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %334 = "tf.StridedSlice"(%327, %11, %333, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %335:2 = "tf.RaggedRange"(%334, %331, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %336 = "tf.GatherV2"(%162, %335#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %337 = "tf.StridedSlice"(%318, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %338 = "tf.StridedSlice"(%318, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %339 = "tf.StridedSlice"(%318, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %340 = "tf.ConcatV2"(%338, %339, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %341 = "tf.StridedSlice"(%318, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %342 = "tf.Mul"(%165, %341) {device = ""} : (tensor, tensor) -> tensor - %343 = "tf.Tile"(%342, %337) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %344 = "tf.Cumsum"(%343, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %345 = "tf.ConcatV2"(%11, %344, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %346 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi64> - %347 = 
"tf.StridedSlice"(%346, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %348 = "tf.Sub"(%347, %13) {device = ""} : (tensor, tensor) -> tensor - %349 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %350 = "tf.LogicalOr"(%349, %2) {device = ""} : (tensor, tensor) -> tensor - %351 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %352 = "tf.LogicalOr"(%350, %351) {device = ""} : (tensor, tensor) -> tensor - %353 = "tf.StridedSlice"(%345, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %354 = "tf.StridedSlice"(%345, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %355 = "tf.Sub"(%353, %354) {device = ""} : (tensor, tensor) -> tensor - %356 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi64> - %357 = "tf.StridedSlice"(%356, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %358 = "tf.Sub"(%357, %13) {device = ""} : (tensor, tensor) -> tensor - %359 = "tf.Equal"(%358, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %360 = "tf.ExpandDims"(%345, %7) {device = ""} : (tensor, tensor) -> tensor - %361 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi32> - %362 = "tf.StridedSlice"(%361, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %363 = "tf.StridedSlice"(%361, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %364 = "tf.StridedSlice"(%361, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %365 = "tf.Select"(%2, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %366 = "tf.Pack"(%365, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %367 = "tf.StridedSlice"(%366, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %368 = "tf.Cast"(%367) {Truncate = false, device = ""} : (tensor) -> tensor - %369 = "tf.Reshape"(%368, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %370 = "tf.Pack"(%7, %369) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %371 = "tf.Tile"(%4, %370) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %372 = "tf.Mul"(%369, %8) {device = ""} : (tensor, tensor) -> tensor - %373 = 
"tf.Pack"(%372) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %374 = "tf.ConcatV2"(%9, %373, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %375 = "tf.Reshape"(%371, %374) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %376 = "tf.Shape"(%375) {device = ""} : (tensor) -> tensor<1xi64> - %377 = "tf.StridedSlice"(%376, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %378 = "tf.Pack"(%367) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %379 = "tf.StridedSlice"(%375, %378, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %380 = "tf.Sub"(%377, %367) {device = ""} : (tensor, tensor) -> tensor - %381 = "tf.Pack"(%380) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %382 = "tf.StridedSlice"(%375, %11, %381, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %383:2 = "tf.RaggedRange"(%382, %379, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %384 = "tf.GatherV2"(%11, %383#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %385 = "tf.GatherV2"(%12, %384, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %386 = "tf.StridedSlice"(%366, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %387 = "tf.StridedSlice"(%366, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %388 = "tf.StridedSlice"(%366, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> - %389 = "tf.ConcatV2"(%387, %388, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> - %390 = "tf.Tile"(%385, %389) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %391 = "tf.StridedSlice"(%366, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %392 = "tf.Mul"(%391, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %393 = "tf.Tile"(%392, %386) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %394 = "tf.Cumsum"(%393, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %395 = "tf.ConcatV2"(%11, %394, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %396 = "tf.StridedSlice"(%395, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor - %397 = "tf.ExpandDims"(%396, %7) {device = ""} : (tensor, tensor) -> tensor - %398 = "tf.Shape"(%396) {device = ""} : (tensor) -> tensor<1xi32> - %399 = "tf.StridedSlice"(%398, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %400 = "tf.Pack"(%399) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %401 = "tf.StridedSlice"(%395, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %402 = "tf.ExpandDims"(%401, %7) {device = ""} : (tensor, tensor) -> tensor - %403 = "tf.Shape"(%401) {device = ""} : (tensor) -> tensor<1xi32> - %404 = "tf.StridedSlice"(%403, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %405 = "tf.Pack"(%404) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %406 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %407 = "tf.Select"(%406, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %408 = "tf.Cast"(%407) {Truncate = false, device = ""} : (tensor) -> tensor - %409 = "tf.Reshape"(%408, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %410 = "tf.Pack"(%7, %409) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %411 = "tf.Mul"(%409, %8) {device = ""} : (tensor, tensor) -> tensor - %412 = "tf.Pack"(%411) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %413 = "tf.ConcatV2"(%9, %412, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %414 = "tf.Pack"(%407) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %415 = "tf.Pack"(%10, %144) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %416 = "tf.ExpandDims"(%415, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %417 = "tf.Tile"(%416, %410) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %418 = "tf.Reshape"(%417, %413) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %419 = "tf.Shape"(%418) {device = ""} : (tensor) -> tensor<1xi64> - %420 = "tf.StridedSlice"(%419, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %421 = "tf.Sub"(%420, %407) {device = ""} : (tensor, tensor) -> tensor - %422 = "tf.Pack"(%421) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %423 = "tf.StridedSlice"(%418, %11, %422, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %424 = "tf.StridedSlice"(%418, %414, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %425:2 = "tf.RaggedRange"(%423, %424, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, 
tensor) - %426 = "tf.GatherV2"(%151, %425#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %427 = "tf.Cast"(%426) {Truncate = false, device = ""} : (tensor) -> tensor - %428 = "tf.BroadcastTo"(%427, %400) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %429 = "tf.Max"(%428, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %430 = "tf.Maximum"(%14, %429) {device = ""} : (tensor, tensor) -> tensor - %431 = "tf.Range"(%14, %430, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %432 = "tf.Pack"(%7, %430) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %433 = "tf.Tile"(%397, %432) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %434 = "tf.Shape"(%433) {device = ""} : (tensor) -> tensor<2xi32> - %435 = "tf.StridedSlice"(%434, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %436 = "tf.Prod"(%435, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %437 = "tf.Pack"(%436) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %438 = "tf.Shape"(%433) {device = ""} : (tensor) -> tensor<2xi32> - %439 = "tf.StridedSlice"(%438, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %440 = "tf.Shape"(%433) {device = ""} : (tensor) -> tensor<2xi32> - %441 = "tf.StridedSlice"(%440, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %442 = "tf.ConcatV2"(%439, %437, %441, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %443 = "tf.Reshape"(%433, %442) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %444 = "tf.ExpandDims"(%428, %3) {device = ""} : (tensor, tensor) -> tensor - %445 = "tf.Less"(%431, %444) {device = ""} : (tensor, tensor) -> tensor - %446 = "tf.Reshape"(%445, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %447 = "tf.Where"(%446) {device = ""} : (tensor) -> tensor - %448 = "tf.Squeeze"(%447) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %449 = "tf.GatherV2"(%443, %448, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %450 = "tf.Cast"(%426) {Truncate = false, device = ""} : (tensor) -> tensor - %451 = "tf.BroadcastTo"(%450, %405) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %452 = "tf.Max"(%451, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %453 = "tf.Maximum"(%14, %452) {device = ""} : (tensor, tensor) -> tensor - %454 = "tf.Range"(%14, %453, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %455 = "tf.Pack"(%7, %453) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %456 = "tf.Tile"(%402, %455) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %457 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> - %458 = "tf.StridedSlice"(%457, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %459 = 
"tf.Prod"(%458, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %460 = "tf.Pack"(%459) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %461 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> - %462 = "tf.StridedSlice"(%461, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %463 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> - %464 = "tf.StridedSlice"(%463, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %465 = "tf.ConcatV2"(%462, %460, %464, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %466 = "tf.Reshape"(%456, %465) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %467 = "tf.ExpandDims"(%451, %3) {device = ""} : (tensor, tensor) -> tensor - %468 = "tf.Less"(%454, %467) {device = ""} : (tensor, tensor) -> tensor - %469 = "tf.Reshape"(%468, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %470 = "tf.Where"(%469) {device = ""} : (tensor) -> tensor - %471 = "tf.Squeeze"(%470) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %472 = "tf.GatherV2"(%466, %471, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %473:2 = "tf.RaggedRange"(%449, %472, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %474 = "tf.GatherV2"(%390, %473#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %475 = "tf.If"(%316, %316, %312, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_16380, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_16370} : (tensor, tensor, tensor, tensor) -> tensor - %476 = "tf.Identity"(%475) {device = ""} : (tensor) -> tensor - %477 = "tf.Select"(%2, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %478 = "tf.Pack"(%477) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %479 = "tf.ConcatV2"(%1, %478, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %480 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %481 = "tf.Equal"(%480, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %482 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %483 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %484 = "tf.Equal"(%483, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %485 = "tf.If"(%484, %484, 
%483, %426) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_16870, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_16860} : (tensor, tensor, tensor, tensor) -> tensor - %486 = "tf.Identity"(%485) {device = ""} : (tensor) -> tensor - %487 = "tf.If"(%481, %481, %426, %482) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_17230, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_17220} : (tensor, tensor, tensor, tensor) -> tensor - %488 = "tf.Identity"(%487) {device = ""} : (tensor) -> tensor - %489 = "tf.If"(%352, %352, %13, %348) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21910, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21900} : (tensor, tensor, tensor, tensor) -> tensor - %490 = "tf.Identity"(%489) {device = ""} : (tensor) -> tensor - %491 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %492 = "tf.Select"(%491, %13, %348) {device = ""} : (tensor, tensor, tensor) -> tensor - %493 = "tf.Equal"(%492, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %494 = "tf.LogicalOr"(%493, %2) {device = ""} : (tensor, tensor) -> tensor - %495 = "tf.Equal"(%492, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %496 = "tf.LogicalOr"(%494, %495) {device = ""} : (tensor, tensor) -> tensor - %497 = "tf.Select"(%359, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %498 = "tf.Pack"(%497, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %499 = "tf.StridedSlice"(%498, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %500 = "tf.Cast"(%499) {Truncate = false, device = ""} : (tensor) -> tensor - %501 = "tf.Reshape"(%500, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %502 = "tf.Pack"(%7, %501) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %503 = "tf.Tile"(%360, %502) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %504 = "tf.Mul"(%501, %363) {device = ""} : (tensor, tensor) -> tensor - %505 = "tf.Pack"(%504) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %506 = "tf.ConcatV2"(%362, %505, %364, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %507 = "tf.Reshape"(%503, %506) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %508 = "tf.Shape"(%507) {device = ""} : (tensor) -> tensor<1xi64> - %509 = "tf.StridedSlice"(%508, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %510 = "tf.Pack"(%499) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %511 = "tf.StridedSlice"(%507, %510, %11, %12) {begin_mask = 0 : 
i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %512 = "tf.Sub"(%509, %499) {device = ""} : (tensor, tensor) -> tensor - %513 = "tf.Pack"(%512) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %514 = "tf.StridedSlice"(%507, %11, %513, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %515:2 = "tf.RaggedRange"(%514, %511, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %516 = "tf.Select"(%2, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %517 = "tf.Pack"(%516, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %518 = "tf.StridedSlice"(%517, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %519 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor - %520 = "tf.Reshape"(%519, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %521 = "tf.Pack"(%7, %520) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %522 = "tf.Tile"(%4, %521) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %523 = "tf.Mul"(%520, %8) {device = ""} : (tensor, tensor) -> tensor - %524 = "tf.Pack"(%523) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %525 = "tf.ConcatV2"(%9, %524, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %526 = "tf.Reshape"(%522, %525) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %527 = "tf.Shape"(%526) {device = ""} : (tensor) -> tensor<1xi64> - %528 = "tf.StridedSlice"(%527, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %529 = "tf.Pack"(%518) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %530 = "tf.StridedSlice"(%526, %529, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %531 = "tf.Sub"(%528, %518) {device = ""} : (tensor, tensor) -> tensor - %532 = "tf.Pack"(%531) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %533 = "tf.StridedSlice"(%526, %11, %532, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %534:2 = "tf.RaggedRange"(%533, %530, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %535 = "tf.StridedSlice"(%517, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> - %536 = "tf.StridedSlice"(%517, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, 
tensor<1xi32>, tensor<1xi32>) -> tensor - %537 = "tf.Mul"(%536, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> - %538 = "tf.Tile"(%537, %535) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor - %539 = "tf.Cumsum"(%538, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor - %540 = "tf.ConcatV2"(%11, %539, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor - %541 = "tf.StridedSlice"(%540, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %542 = "tf.ExpandDims"(%541, %7) {device = ""} : (tensor, tensor) -> tensor - %543 = "tf.Shape"(%541) {device = ""} : (tensor) -> tensor<1xi32> - %544 = "tf.StridedSlice"(%543, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %545 = "tf.Pack"(%544) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %546 = "tf.StridedSlice"(%540, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %547 = "tf.ExpandDims"(%546, %7) {device = ""} : (tensor, tensor) -> tensor - %548 = "tf.Shape"(%546) {device = ""} : (tensor) -> tensor<1xi32> - %549 = "tf.StridedSlice"(%548, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %550 = "tf.Pack"(%549) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %551 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %552 = "tf.Select"(%551, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %553 = "tf.Cast"(%552) {Truncate = false, device = ""} : (tensor) -> tensor - %554 = "tf.Reshape"(%553, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor - %555 = "tf.Pack"(%7, %554) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %556 = "tf.Mul"(%554, %8) {device = ""} : (tensor, tensor) -> tensor - %557 = "tf.Pack"(%556) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %558 = "tf.ConcatV2"(%9, %557, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %559 = "tf.Pack"(%552) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %560 = "tf.Pack"(%10, %348) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> - %561 = "tf.ExpandDims"(%560, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> - %562 = "tf.Tile"(%561, %555) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> - %563 = "tf.Reshape"(%562, %558) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor - %564 = "tf.Shape"(%563) {device = ""} : (tensor) -> tensor<1xi64> - %565 = "tf.StridedSlice"(%564, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %566 = "tf.Sub"(%565, %552) {device = ""} : (tensor, tensor) -> tensor - %567 = "tf.Pack"(%566) {axis = 0 : i64, device = ""} : (tensor) -> 
tensor<1xi64> - %568 = "tf.StridedSlice"(%563, %11, %567, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %569 = "tf.StridedSlice"(%563, %559, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor - %570:2 = "tf.RaggedRange"(%568, %569, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %571 = "tf.GatherV2"(%355, %570#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %572 = "tf.Cast"(%571) {Truncate = false, device = ""} : (tensor) -> tensor - %573 = "tf.BroadcastTo"(%572, %545) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %574 = "tf.Max"(%573, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %575 = "tf.Maximum"(%14, %574) {device = ""} : (tensor, tensor) -> tensor - %576 = "tf.Range"(%14, %575, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %577 = "tf.Pack"(%7, %575) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %578 = "tf.Tile"(%542, %577) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %579 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> - %580 = "tf.StridedSlice"(%579, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %581 = "tf.Prod"(%580, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %582 = "tf.Pack"(%581) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %583 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> - %584 = "tf.StridedSlice"(%583, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %585 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> - %586 = "tf.StridedSlice"(%585, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %587 = "tf.ConcatV2"(%584, %582, %586, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %588 = "tf.Reshape"(%578, %587) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %589 = "tf.ExpandDims"(%573, %3) {device = ""} : (tensor, tensor) -> tensor - %590 = "tf.Less"(%576, %589) {device = ""} : (tensor, tensor) -> tensor - %591 = "tf.Reshape"(%590, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %592 = "tf.Where"(%591) {device = ""} : (tensor) -> tensor - %593 = "tf.Squeeze"(%592) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %594 = "tf.GatherV2"(%588, %593, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %595 = "tf.Cast"(%571) {Truncate = false, device = ""} : (tensor) -> tensor - %596 = "tf.BroadcastTo"(%595, %550) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %597 = "tf.Max"(%596, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor - %598 = "tf.Maximum"(%14, %597) {device = ""} 
: (tensor, tensor) -> tensor - %599 = "tf.Range"(%14, %598, %7) {device = ""} : (tensor, tensor, tensor) -> tensor - %600 = "tf.Pack"(%7, %598) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> - %601 = "tf.Tile"(%547, %600) {device = ""} : (tensor, tensor<2xi32>) -> tensor - %602 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> - %603 = "tf.StridedSlice"(%602, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - %604 = "tf.Prod"(%603, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor - %605 = "tf.Pack"(%604) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %606 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> - %607 = "tf.StridedSlice"(%606, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %608 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> - %609 = "tf.StridedSlice"(%608, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %610 = "tf.ConcatV2"(%607, %605, %609, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> - %611 = "tf.Reshape"(%601, %610) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %612 = "tf.ExpandDims"(%596, %3) {device = ""} : (tensor, tensor) -> tensor - %613 = "tf.Less"(%599, %612) {device = ""} : (tensor, tensor) -> tensor - %614 = "tf.Reshape"(%613, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor - %615 = "tf.Where"(%614) {device = ""} : (tensor) -> tensor - %616 = "tf.Squeeze"(%615) {device = "", squeeze_dims = [1]} : (tensor) -> tensor - %617 = "tf.GatherV2"(%611, %616, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %618:2 = "tf.RaggedRange"(%594, %617, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) - %619 = "tf.If"(%496, %496, %492, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22980, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22970} : (tensor, tensor, tensor, tensor) -> tensor - %620 = "tf.Identity"(%619) {device = ""} : (tensor) -> tensor - %621 = "tf.Select"(%2, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor - %622 = "tf.Pack"(%621) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> - %623 = "tf.ConcatV2"(%1, %622, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> - %624 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %625 = "tf.Equal"(%624, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %626 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 
: i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %627 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %628 = "tf.Equal"(%627, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor - %629 = "tf.If"(%628, %628, %627, %571) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23460} : (tensor, tensor, tensor, tensor) -> tensor - %630 = "tf.Identity"(%629) {device = ""} : (tensor) -> tensor - %631 = "tf.If"(%625, %625, %571, %626) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23830, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23820} : (tensor, tensor, tensor, tensor) -> tensor - %632 = "tf.Identity"(%631) {device = ""} : (tensor) -> tensor - %633 = "tf.Identity"(%308) {device = ""} : (tensor) -> tensor - %634 = "tf.Shape"(%37#2) {device = ""} : (tensor) -> tensor<1xi32> - %635 = "tf.StridedSlice"(%634, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %636 = "tf.Cast"(%635) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %637 = "tf.Identity"(%636) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %638 = "tf.Shape"(%37#3) {device = ""} : (tensor) -> tensor<1xi32> - %639 = "tf.StridedSlice"(%638, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %640 = "tf.Cast"(%639) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %641 = "tf.Identity"(%640) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %642 = "tf.GatherV2"(%37#3, %336, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor - %643 = "tf.Tile"(%642, %340) {device = ""} : (tensor, tensor<1xi64>) -> tensor - %644 = "tf.Sub"(%643, %474) {device = ""} : (tensor, tensor) -> tensor - %645 = "tf.Shape"(%644) {device = ""} : (tensor) -> tensor<1xi32> - %646 = "tf.StridedSlice"(%645, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> - %647 = "tf.Cast"(%646) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> - %648 = "tf.Identity"(%647) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> - %649 = "tf.UnicodeEncode"(%37#0, %58) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor - %650 = "tf.Identity"(%649) {device = ""} : (tensor) 
-> tensor - return %650 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedFromTensor/Const:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedNRows/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3980(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4340(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid 
RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4330(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5050(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : 
tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5040(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5410(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5400(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5770(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5760(%arg0: tensor, %arg1: tensor, 
%arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6120(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6480(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6470(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6820(%arg0: tensor, %arg1: tensor, 
%arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7190(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7550(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor 
- %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () - %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor - return %4 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7540(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7890(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7880(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8690(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func 
@WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9760(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10250(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10610(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, 
tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_15310(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_15300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_16380(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_16370(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_16870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, 
%arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_16860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_17230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_17220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21910(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21900(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22980(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %2 = "tf.Const"() {value = 
dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23830(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { - %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor - %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor - "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () - %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor - return %5 : tensor - } - func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23820(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { - %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor - %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor - return %1 : tensor - } - - // CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} { - // CHECK: %0 = "tfl.custom"(%arg0) 
{custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor) -> tensor - // CHECK: return %0 : tensor +func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor, tensor) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> + %2 = "tf.Const"() {value = dense : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %4 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> + %5 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %6 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %7 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %8 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %9 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %10 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %11 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> + %12 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %13 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %14 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %15 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %16 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %17 = "tf.If"(%2, %2, %13, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3210, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3200} : (tensor, tensor, tensor, tensor) -> tensor + %18 = "tf.Identity"(%17) {device = ""} : (tensor) -> tensor + %19 = "tf.StringLength"(%arg0) {device = "", unit = "BYTE"} : (tensor<1x!tf.string>) -> tensor<1xi32> + %20 = "tf.ExpandDims"(%19, %7) {device = ""} : (tensor<1xi32>, tensor) -> tensor<1x1xi32> + %21 = "tf.Cast"(%20) {Truncate = false, device = ""} : (tensor<1x1xi32>) -> tensor<1x1xi64> + %22 = "tf.Reshape"(%21, %12) {device = ""} : (tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> + %23 = "tf.Reshape"(%arg0, %5) {device = ""} : (tensor<1x!tf.string>, tensor<1xi32>) -> tensor<1x!tf.string> + %24:3 = "tf.UnicodeDecodeWithOffsets"(%23) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor<1x!tf.string>) -> (tensor<2xi64>, tensor, tensor) + %25 = "tf.StridedSlice"(%24#0, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %26 = "tf.AddV2"(%25, %13) {device = ""} : (tensor<1xi64>, tensor) -> tensor<1xi64> + %27 = "tf.StridedSlice"(%24#0, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %28 = "tf.Minimum"(%26, %27) {device = ""} : (tensor<1xi64>, 
tensor<1xi64>) -> tensor<1xi64> + %29:2 = "tf.RaggedRange"(%28, %27, %13) {T = i64, Tsplits = i64, device = ""} : (tensor<1xi64>, tensor<1xi64>, tensor) -> (tensor<2xi64>, tensor) + %30 = "tf.StridedSlice"(%29#0, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %31 = "tf.AddV2"(%30, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %32 = "tf.ConcatV2"(%29#0, %31, %14) {device = ""} : (tensor<2xi64>, tensor<1xi64>, tensor) -> tensor<3xi64> + %33 = "tf.GatherV2"(%24#2, %29#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %34 = "tf.ConcatV2"(%33, %22, %14) {device = ""} : (tensor, tensor<1xi64>, tensor) -> tensor + %35:2 = "tf.RaggedGather"(%32, %34, %0) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor<3xi64>, tensor, tensor<2xi64>) -> (tensor, tensor) + %36:5 = "tf.WhitespaceTokenizeWithOffsets"(%24#1, %24#0) {Tsplits = i64, device = ""} : (tensor, tensor<2xi64>) -> (tensor, tensor, tensor, tensor, tensor) + %37 = "tf.StridedSlice"(%36#1, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %38 = "tf.Equal"(%37, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %39 = "tf.All"(%38, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %40 = "tf.If"(%39, %39, %37, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3960} : (tensor, tensor, tensor, tensor) -> tensor + %41 = "tf.Identity"(%40) {device = ""} : (tensor) -> tensor + %42 = "tf.StridedSlice"(%36#1, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %43 = "tf.StridedSlice"(%36#1, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %44 = "tf.Sub"(%42, %43) {device = ""} : (tensor, tensor) -> tensor + %45 = "tf.LessEqual"(%10, %44) {device = ""} : (tensor, tensor) -> tensor + %46 = "tf.All"(%45, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %47 = "tf.If"(%46, %46, %44) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4330, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4320} : (tensor, tensor, tensor) -> tensor + %48 = 
"tf.Identity"(%47) {device = ""} : (tensor) -> tensor + %49 = "tf.Identity"(%36#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %50 = "tf.StridedSlice"(%49, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %51 = "tf.Shape"(%36#0) {device = ""} : (tensor) -> tensor<1xi64> + %52 = "tf.StridedSlice"(%51, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %53 = "tf.Equal"(%50, %52) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %54 = "tf.All"(%53, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %55 = "tf.If"(%54, %54, %50, %52) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4670, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4660} : (tensor, tensor, tensor, tensor) -> tensor + %56 = "tf.Identity"(%55) {device = ""} : (tensor) -> tensor + %57 = "tf.Identity"(%49) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %58 = "tf.Shape"(%57) {device = ""} : (tensor) -> tensor<1xi64> + %59 = "tf.StridedSlice"(%58, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %60 = "tf.Sub"(%59, %13) {device = ""} : (tensor, tensor) -> tensor + %61 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %62 = "tf.Equal"(%61, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %63 = "tf.All"(%62, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %64 = "tf.If"(%63, %63, %61, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5040, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5030} : (tensor, tensor, tensor, tensor) -> tensor + %65 = "tf.Identity"(%64) {device = ""} : (tensor) -> tensor + %66 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %67 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %68 = "tf.Sub"(%66, %67) 
{device = ""} : (tensor, tensor) -> tensor + %69 = "tf.LessEqual"(%10, %68) {device = ""} : (tensor, tensor) -> tensor + %70 = "tf.All"(%69, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %71 = "tf.If"(%70, %70, %68) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5400, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5390} : (tensor, tensor, tensor) -> tensor + %72 = "tf.Identity"(%71) {device = ""} : (tensor) -> tensor + %73 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %74 = "tf.StridedSlice"(%73, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %75 = "tf.Equal"(%74, %60) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %76 = "tf.All"(%75, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %77 = "tf.If"(%76, %76, %74, %60) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5750} : (tensor, tensor, tensor, tensor) -> tensor + %78 = "tf.Identity"(%77) {device = ""} : (tensor) -> tensor + %79 = "tf.Identity"(%73) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %80 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %81 = "tf.Equal"(%80, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %82 = "tf.All"(%81, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %83 = "tf.If"(%82, %82, %80, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6110, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6100} : (tensor, tensor, tensor, tensor) -> tensor + %84 = "tf.Identity"(%83) {device = ""} : (tensor) -> tensor + %85 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %86 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %87 = "tf.Sub"(%85, %86) {device = ""} : (tensor, tensor) -> tensor + %88 = "tf.LessEqual"(%10, %87) {device = ""} : (tensor, tensor) -> tensor + %89 = "tf.All"(%88, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %90 = "tf.If"(%89, %89, %87) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6460} : (tensor, tensor, tensor) -> tensor + %91 = "tf.Identity"(%90) {device = ""} : (tensor) -> tensor + %92 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %93 = "tf.StridedSlice"(%92, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %94 = "tf.Shape"(%36#2) {device = ""} : (tensor) -> tensor<1xi64> + %95 = "tf.StridedSlice"(%94, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %96 = "tf.Equal"(%93, %95) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %97 = "tf.All"(%96, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %98 = "tf.If"(%97, %97, %93, %95) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6810, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6800} : (tensor, tensor, tensor, tensor) -> tensor + %99 = "tf.Identity"(%98) {device = ""} : (tensor) -> tensor + %100 = "tf.Identity"(%92) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %101 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi64> + %102 = "tf.StridedSlice"(%101, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %103 = "tf.Sub"(%102, %13) {device = ""} : (tensor, tensor) -> tensor + %104 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %105 = "tf.LogicalOr"(%104, %2) {device = ""} : (tensor, tensor) -> tensor + %106 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %107 = "tf.LogicalOr"(%105, %106) {device = ""} : (tensor, tensor) -> tensor + %108 = "tf.StridedSlice"(%100, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %109 = "tf.StridedSlice"(%100, %15, %5, %16) {begin_mask = 1 : 
i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %110 = "tf.Sub"(%108, %109) {device = ""} : (tensor, tensor) -> tensor + %111 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi64> + %112 = "tf.StridedSlice"(%111, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %113 = "tf.Sub"(%112, %13) {device = ""} : (tensor, tensor) -> tensor + %114 = "tf.Equal"(%113, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %115 = "tf.ExpandDims"(%100, %7) {device = ""} : (tensor, tensor) -> tensor + %116 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<1xi32> + %117 = "tf.StridedSlice"(%116, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %118 = "tf.StridedSlice"(%116, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %119 = "tf.StridedSlice"(%116, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %120 = "tf.StridedSlice"(%36#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %121 = "tf.Equal"(%120, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %122 = "tf.All"(%121, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %123 = "tf.If"(%122, %122, %120, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7180, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7170} : (tensor, tensor, tensor, tensor) -> tensor + %124 = "tf.Identity"(%123) {device = ""} : (tensor) -> tensor + %125 = "tf.StridedSlice"(%36#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %126 = "tf.StridedSlice"(%36#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %127 = "tf.Sub"(%125, %126) {device = ""} : (tensor, tensor) -> tensor + %128 = "tf.LessEqual"(%10, %127) {device = ""} : (tensor, tensor) -> tensor + %129 = "tf.All"(%128, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %130 = "tf.If"(%129, %129, %127) {_lower_using_switch_merge = true, 
_read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7530} : (tensor, tensor, tensor) -> tensor + %131 = "tf.Identity"(%130) {device = ""} : (tensor) -> tensor + %132 = "tf.Identity"(%36#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %133 = "tf.StridedSlice"(%132, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %134 = "tf.Shape"(%36#3) {device = ""} : (tensor) -> tensor<1xi64> + %135 = "tf.StridedSlice"(%134, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %136 = "tf.Equal"(%133, %135) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %137 = "tf.All"(%136, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %138 = "tf.If"(%137, %137, %133, %135) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7880, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7870} : (tensor, tensor, tensor, tensor) -> tensor + %139 = "tf.Identity"(%138) {device = ""} : (tensor) -> tensor + %140 = "tf.Identity"(%132) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %141 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi64> + %142 = "tf.StridedSlice"(%141, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %143 = "tf.Sub"(%142, %13) {device = ""} : (tensor, tensor) -> tensor + %144 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %145 = "tf.LogicalOr"(%144, %2) {device = ""} : (tensor, tensor) -> tensor + %146 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %147 = "tf.LogicalOr"(%145, %146) {device = ""} : (tensor, tensor) -> tensor + %148 = "tf.StridedSlice"(%140, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %149 = "tf.StridedSlice"(%140, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %150 = "tf.Sub"(%148, %149) {device = ""} : (tensor, tensor) -> tensor + %151 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi64> + %152 = 
"tf.StridedSlice"(%151, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %153 = "tf.Sub"(%152, %13) {device = ""} : (tensor, tensor) -> tensor + %154 = "tf.Equal"(%153, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %155 = "tf.ExpandDims"(%140, %7) {device = ""} : (tensor, tensor) -> tensor + %156 = "tf.Shape"(%140) {device = ""} : (tensor) -> tensor<1xi32> + %157 = "tf.StridedSlice"(%156, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %158 = "tf.StridedSlice"(%156, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %159 = "tf.StridedSlice"(%156, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %160 = "tf.StridedSlice"(%140, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %161 = "tf.Range"(%10, %160, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %162 = "tf.StridedSlice"(%140, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %163 = "tf.StridedSlice"(%140, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %164 = "tf.Sub"(%162, %163) {device = ""} : (tensor, tensor) -> tensor + %165 = "tf.If"(%107, %107, %13, %103) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8680, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8670} : (tensor, tensor, tensor, tensor) -> tensor + %166 = "tf.Identity"(%165) {device = ""} : (tensor) -> tensor + %167 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %168 = "tf.Select"(%167, %13, %103) {device = ""} : (tensor, tensor, tensor) -> tensor + %169 = "tf.Equal"(%168, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %170 = "tf.LogicalOr"(%169, %2) {device = ""} : (tensor, tensor) -> tensor + %171 = "tf.Equal"(%168, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %172 = "tf.LogicalOr"(%170, %171) {device = ""} : (tensor, tensor) -> tensor + %173 = "tf.Select"(%114, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %174 = "tf.Pack"(%173, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %175 = "tf.StridedSlice"(%174, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, 
end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %176 = "tf.Cast"(%175) {Truncate = false, device = ""} : (tensor) -> tensor + %177 = "tf.Reshape"(%176, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %178 = "tf.Pack"(%7, %177) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %179 = "tf.Tile"(%115, %178) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %180 = "tf.Mul"(%177, %118) {device = ""} : (tensor, tensor) -> tensor + %181 = "tf.Pack"(%180) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %182 = "tf.ConcatV2"(%117, %181, %119, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %183 = "tf.Reshape"(%179, %182) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %184 = "tf.Shape"(%183) {device = ""} : (tensor) -> tensor<1xi64> + %185 = "tf.StridedSlice"(%184, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %186 = "tf.Pack"(%175) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %187 = "tf.StridedSlice"(%183, %186, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %188 = "tf.Sub"(%185, %175) {device = ""} : (tensor, tensor) -> tensor + %189 = "tf.Pack"(%188) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %190 = "tf.StridedSlice"(%183, %11, %189, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %191:2 = "tf.RaggedRange"(%190, %187, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %192 = "tf.Select"(%2, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %193 = "tf.Pack"(%192, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %194 = "tf.StridedSlice"(%193, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %195 = "tf.Cast"(%194) {Truncate = false, device = ""} : (tensor) -> tensor + %196 = "tf.Reshape"(%195, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %197 = "tf.Pack"(%7, %196) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %198 = "tf.Tile"(%4, %197) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %199 = "tf.Mul"(%196, %8) {device = ""} : (tensor, tensor) -> tensor + %200 = "tf.Pack"(%199) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %201 = "tf.ConcatV2"(%9, %200, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %202 = "tf.Reshape"(%198, %201) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %203 = "tf.Shape"(%202) {device = ""} : (tensor) -> tensor<1xi64> + %204 = "tf.StridedSlice"(%203, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %205 = 
"tf.Pack"(%194) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %206 = "tf.StridedSlice"(%202, %205, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %207 = "tf.Sub"(%204, %194) {device = ""} : (tensor, tensor) -> tensor + %208 = "tf.Pack"(%207) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %209 = "tf.StridedSlice"(%202, %11, %208, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %210:2 = "tf.RaggedRange"(%209, %206, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %211 = "tf.StridedSlice"(%193, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %212 = "tf.StridedSlice"(%193, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %213 = "tf.Mul"(%212, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %214 = "tf.Tile"(%213, %211) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %215 = "tf.Cumsum"(%214, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %216 = "tf.ConcatV2"(%11, %215, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %217 = "tf.StridedSlice"(%216, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %218 = "tf.ExpandDims"(%217, %7) {device = ""} : (tensor, tensor) -> tensor + %219 = "tf.Shape"(%217) {device = ""} : (tensor) -> tensor<1xi32> + %220 = "tf.StridedSlice"(%219, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %221 = "tf.Pack"(%220) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %222 = "tf.StridedSlice"(%216, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %223 = "tf.ExpandDims"(%222, %7) {device = ""} : (tensor, tensor) -> tensor + %224 = "tf.Shape"(%222) {device = ""} : (tensor) -> tensor<1xi32> + %225 = "tf.StridedSlice"(%224, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %226 = "tf.Pack"(%225) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %227 = "tf.Equal"(%103, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %228 = "tf.Select"(%227, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %229 = "tf.Cast"(%228) {Truncate = false, device = ""} : (tensor) -> tensor + %230 = "tf.Reshape"(%229, %9) {device = ""} : (tensor, 
tensor<0xi32>) -> tensor + %231 = "tf.Pack"(%7, %230) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %232 = "tf.Mul"(%230, %8) {device = ""} : (tensor, tensor) -> tensor + %233 = "tf.Pack"(%232) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %234 = "tf.ConcatV2"(%9, %233, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %235 = "tf.Pack"(%228) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %236 = "tf.Pack"(%10, %103) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %237 = "tf.ExpandDims"(%236, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %238 = "tf.Tile"(%237, %231) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %239 = "tf.Reshape"(%238, %234) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %240 = "tf.Shape"(%239) {device = ""} : (tensor) -> tensor<1xi64> + %241 = "tf.StridedSlice"(%240, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %242 = "tf.Sub"(%241, %228) {device = ""} : (tensor, tensor) -> tensor + %243 = "tf.Pack"(%242) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %244 = "tf.StridedSlice"(%239, %11, %243, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %245 = "tf.StridedSlice"(%239, %235, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %246:2 = "tf.RaggedRange"(%244, %245, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %247 = "tf.GatherV2"(%110, %246#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %248 = "tf.Cast"(%247) {Truncate = false, device = ""} : (tensor) -> tensor + %249 = "tf.BroadcastTo"(%248, %221) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %250 = "tf.Max"(%249, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %251 = "tf.Maximum"(%14, %250) {device = ""} : (tensor, tensor) -> tensor + %252 = "tf.Range"(%14, %251, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %253 = "tf.Pack"(%7, %251) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %254 = "tf.Tile"(%218, %253) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %255 = "tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> + %256 = "tf.StridedSlice"(%255, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %257 = "tf.Prod"(%256, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %258 = "tf.Pack"(%257) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %259 = "tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> + %260 = "tf.StridedSlice"(%259, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %261 = 
"tf.Shape"(%254) {device = ""} : (tensor) -> tensor<2xi32> + %262 = "tf.StridedSlice"(%261, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %263 = "tf.ConcatV2"(%260, %258, %262, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %264 = "tf.Reshape"(%254, %263) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %265 = "tf.ExpandDims"(%249, %3) {device = ""} : (tensor, tensor) -> tensor + %266 = "tf.Less"(%252, %265) {device = ""} : (tensor, tensor) -> tensor + %267 = "tf.Reshape"(%266, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %268 = "tf.Where"(%267) {device = ""} : (tensor) -> tensor + %269 = "tf.Squeeze"(%268) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %270 = "tf.GatherV2"(%264, %269, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %271 = "tf.Cast"(%247) {Truncate = false, device = ""} : (tensor) -> tensor + %272 = "tf.BroadcastTo"(%271, %226) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %273 = "tf.Max"(%272, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %274 = "tf.Maximum"(%14, %273) {device = ""} : (tensor, tensor) -> tensor + %275 = "tf.Range"(%14, %274, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %276 = "tf.Pack"(%7, %274) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %277 = "tf.Tile"(%223, %276) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %278 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> + %279 = "tf.StridedSlice"(%278, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %280 = "tf.Prod"(%279, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %281 = "tf.Pack"(%280) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %282 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> + %283 = "tf.StridedSlice"(%282, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %284 = "tf.Shape"(%277) {device = ""} : (tensor) -> tensor<2xi32> + %285 = "tf.StridedSlice"(%284, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %286 = "tf.ConcatV2"(%283, %281, %285, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %287 = "tf.Reshape"(%277, %286) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %288 = "tf.ExpandDims"(%272, %3) {device = ""} : (tensor, tensor) -> tensor + %289 = "tf.Less"(%275, %288) {device = ""} : (tensor, tensor) -> tensor + %290 = "tf.Reshape"(%289, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %291 = "tf.Where"(%290) {device = ""} : (tensor) -> tensor + %292 = "tf.Squeeze"(%291) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %293 = "tf.GatherV2"(%287, %292, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %294:2 = "tf.RaggedRange"(%270, %293, 
%13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %295 = "tf.If"(%172, %172, %168, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9750, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9740} : (tensor, tensor, tensor, tensor) -> tensor + %296 = "tf.Identity"(%295) {device = ""} : (tensor) -> tensor + %297 = "tf.Select"(%2, %168, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %298 = "tf.Pack"(%297) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %299 = "tf.ConcatV2"(%1, %298, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %300 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %301 = "tf.Equal"(%300, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %302 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %303 = "tf.StridedSlice"(%299, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %304 = "tf.Equal"(%303, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %305 = "tf.If"(%304, %304, %303, %247) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10240, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10230} : (tensor, tensor, tensor, tensor) -> tensor + %306 = "tf.Identity"(%305) {device = ""} : (tensor) -> tensor + %307 = "tf.If"(%301, %301, %247, %302) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10600, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10590} : (tensor, tensor, tensor, tensor) -> tensor + %308 = "tf.If"(%147, %147, %13, %143) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_AssertGuard_false_15300, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_AssertGuard_true_15290} : (tensor, tensor, tensor, tensor) -> tensor + %309 = "tf.Identity"(%308) {device = ""} : (tensor) -> tensor + %310 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %311 = "tf.Select"(%310, %13, %143) {device = ""} : (tensor, tensor, tensor) -> tensor + %312 = "tf.Equal"(%311, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %313 = "tf.LogicalOr"(%312, %2) {device = ""} : (tensor, tensor) -> tensor + %314 = "tf.Equal"(%311, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %315 = 
"tf.LogicalOr"(%313, %314) {device = ""} : (tensor, tensor) -> tensor + %316 = "tf.Select"(%154, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %317 = "tf.Pack"(%316, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %318 = "tf.StridedSlice"(%317, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %319 = "tf.Cast"(%318) {Truncate = false, device = ""} : (tensor) -> tensor + %320 = "tf.Reshape"(%319, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %321 = "tf.Pack"(%7, %320) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %322 = "tf.Tile"(%155, %321) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %323 = "tf.Mul"(%320, %158) {device = ""} : (tensor, tensor) -> tensor + %324 = "tf.Pack"(%323) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %325 = "tf.ConcatV2"(%157, %324, %159, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %326 = "tf.Reshape"(%322, %325) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %327 = "tf.Shape"(%326) {device = ""} : (tensor) -> tensor<1xi64> + %328 = "tf.StridedSlice"(%327, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %329 = "tf.Pack"(%318) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %330 = "tf.StridedSlice"(%326, %329, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %331 = "tf.Sub"(%328, %318) {device = ""} : (tensor, tensor) -> tensor + %332 = "tf.Pack"(%331) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %333 = "tf.StridedSlice"(%326, %11, %332, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %334:2 = "tf.RaggedRange"(%333, %330, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %335 = "tf.GatherV2"(%161, %334#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %336 = "tf.StridedSlice"(%317, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %337 = "tf.StridedSlice"(%317, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %338 = "tf.StridedSlice"(%317, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %339 = "tf.ConcatV2"(%337, %338, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %340 = "tf.StridedSlice"(%317, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, 
new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %341 = "tf.Mul"(%164, %340) {device = ""} : (tensor, tensor) -> tensor + %342 = "tf.Tile"(%341, %336) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %343 = "tf.Cumsum"(%342, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %344 = "tf.ConcatV2"(%11, %343, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %345 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi64> + %346 = "tf.StridedSlice"(%345, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %347 = "tf.Sub"(%346, %13) {device = ""} : (tensor, tensor) -> tensor + %348 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %349 = "tf.LogicalOr"(%348, %2) {device = ""} : (tensor, tensor) -> tensor + %350 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %351 = "tf.LogicalOr"(%349, %350) {device = ""} : (tensor, tensor) -> tensor + %352 = "tf.StridedSlice"(%344, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %353 = "tf.StridedSlice"(%344, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %354 = "tf.Sub"(%352, %353) {device = ""} : (tensor, tensor) -> tensor + %355 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi64> + %356 = "tf.StridedSlice"(%355, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %357 = "tf.Sub"(%356, %13) {device = ""} : (tensor, tensor) -> tensor + %358 = "tf.Equal"(%357, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %359 = "tf.ExpandDims"(%344, %7) {device = ""} : (tensor, tensor) -> tensor + %360 = "tf.Shape"(%344) {device = ""} : (tensor) -> tensor<1xi32> + %361 = "tf.StridedSlice"(%360, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %362 = "tf.StridedSlice"(%360, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %363 = "tf.StridedSlice"(%360, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %364 = "tf.Select"(%2, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %365 = "tf.Pack"(%364, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %366 = "tf.StridedSlice"(%365, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : 
i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %367 = "tf.Cast"(%366) {Truncate = false, device = ""} : (tensor) -> tensor + %368 = "tf.Reshape"(%367, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %369 = "tf.Pack"(%7, %368) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %370 = "tf.Tile"(%4, %369) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %371 = "tf.Mul"(%368, %8) {device = ""} : (tensor, tensor) -> tensor + %372 = "tf.Pack"(%371) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %373 = "tf.ConcatV2"(%9, %372, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %374 = "tf.Reshape"(%370, %373) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %375 = "tf.Shape"(%374) {device = ""} : (tensor) -> tensor<1xi64> + %376 = "tf.StridedSlice"(%375, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %377 = "tf.Pack"(%366) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %378 = "tf.StridedSlice"(%374, %377, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %379 = "tf.Sub"(%376, %366) {device = ""} : (tensor, tensor) -> tensor + %380 = "tf.Pack"(%379) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %381 = "tf.StridedSlice"(%374, %11, %380, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %382:2 = "tf.RaggedRange"(%381, %378, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %383 = "tf.GatherV2"(%11, %382#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %384 = "tf.GatherV2"(%12, %383, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %385 = "tf.StridedSlice"(%365, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %386 = "tf.StridedSlice"(%365, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %387 = "tf.StridedSlice"(%365, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %388 = "tf.ConcatV2"(%386, %387, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %389 = "tf.Tile"(%384, %388) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %390 = "tf.StridedSlice"(%365, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %391 = "tf.Mul"(%390, %12) 
{device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %392 = "tf.Tile"(%391, %385) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %393 = "tf.Cumsum"(%392, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %394 = "tf.ConcatV2"(%11, %393, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %395 = "tf.StridedSlice"(%394, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %396 = "tf.ExpandDims"(%395, %7) {device = ""} : (tensor, tensor) -> tensor + %397 = "tf.Shape"(%395) {device = ""} : (tensor) -> tensor<1xi32> + %398 = "tf.StridedSlice"(%397, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %399 = "tf.Pack"(%398) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %400 = "tf.StridedSlice"(%394, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %401 = "tf.ExpandDims"(%400, %7) {device = ""} : (tensor, tensor) -> tensor + %402 = "tf.Shape"(%400) {device = ""} : (tensor) -> tensor<1xi32> + %403 = "tf.StridedSlice"(%402, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %404 = "tf.Pack"(%403) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %405 = "tf.Equal"(%143, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %406 = "tf.Select"(%405, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %407 = "tf.Cast"(%406) {Truncate = false, device = ""} : (tensor) -> tensor + %408 = "tf.Reshape"(%407, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %409 = "tf.Pack"(%7, %408) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %410 = "tf.Mul"(%408, %8) {device = ""} : (tensor, tensor) -> tensor + %411 = "tf.Pack"(%410) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %412 = "tf.ConcatV2"(%9, %411, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %413 = "tf.Pack"(%406) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %414 = "tf.Pack"(%10, %143) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %415 = "tf.ExpandDims"(%414, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %416 = "tf.Tile"(%415, %409) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %417 = "tf.Reshape"(%416, %412) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %418 = "tf.Shape"(%417) {device = ""} : (tensor) -> tensor<1xi64> + %419 = "tf.StridedSlice"(%418, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %420 = "tf.Sub"(%419, %406) {device = ""} : (tensor, tensor) -> tensor + %421 = "tf.Pack"(%420) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %422 = "tf.StridedSlice"(%417, %11, %421, %12) {begin_mask 
= 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %423 = "tf.StridedSlice"(%417, %413, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %424:2 = "tf.RaggedRange"(%422, %423, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %425 = "tf.GatherV2"(%150, %424#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %426 = "tf.Cast"(%425) {Truncate = false, device = ""} : (tensor) -> tensor + %427 = "tf.BroadcastTo"(%426, %399) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %428 = "tf.Max"(%427, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %429 = "tf.Maximum"(%14, %428) {device = ""} : (tensor, tensor) -> tensor + %430 = "tf.Range"(%14, %429, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %431 = "tf.Pack"(%7, %429) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %432 = "tf.Tile"(%396, %431) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %433 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> + %434 = "tf.StridedSlice"(%433, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %435 = "tf.Prod"(%434, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %436 = "tf.Pack"(%435) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %437 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> + %438 = "tf.StridedSlice"(%437, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %439 = "tf.Shape"(%432) {device = ""} : (tensor) -> tensor<2xi32> + %440 = "tf.StridedSlice"(%439, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %441 = "tf.ConcatV2"(%438, %436, %440, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %442 = "tf.Reshape"(%432, %441) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %443 = "tf.ExpandDims"(%427, %3) {device = ""} : (tensor, tensor) -> tensor + %444 = "tf.Less"(%430, %443) {device = ""} : (tensor, tensor) -> tensor + %445 = "tf.Reshape"(%444, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %446 = "tf.Where"(%445) {device = ""} : (tensor) -> tensor + %447 = "tf.Squeeze"(%446) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %448 = "tf.GatherV2"(%442, %447, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %449 = "tf.Cast"(%425) {Truncate = false, device = ""} : (tensor) -> tensor + %450 = "tf.BroadcastTo"(%449, %404) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %451 = "tf.Max"(%450, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %452 = "tf.Maximum"(%14, %451) {device = ""} : (tensor, tensor) -> tensor + %453 = "tf.Range"(%14, %452, %7) {device = 
""} : (tensor, tensor, tensor) -> tensor + %454 = "tf.Pack"(%7, %452) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %455 = "tf.Tile"(%401, %454) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %456 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> + %457 = "tf.StridedSlice"(%456, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %458 = "tf.Prod"(%457, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %459 = "tf.Pack"(%458) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %460 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> + %461 = "tf.StridedSlice"(%460, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %462 = "tf.Shape"(%455) {device = ""} : (tensor) -> tensor<2xi32> + %463 = "tf.StridedSlice"(%462, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %464 = "tf.ConcatV2"(%461, %459, %463, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %465 = "tf.Reshape"(%455, %464) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %466 = "tf.ExpandDims"(%450, %3) {device = ""} : (tensor, tensor) -> tensor + %467 = "tf.Less"(%453, %466) {device = ""} : (tensor, tensor) -> tensor + %468 = "tf.Reshape"(%467, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %469 = "tf.Where"(%468) {device = ""} : (tensor) -> tensor + %470 = "tf.Squeeze"(%469) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %471 = "tf.GatherV2"(%465, %470, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %472:2 = "tf.RaggedRange"(%448, %471, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %473 = "tf.GatherV2"(%389, %472#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %474 = "tf.If"(%315, %315, %311, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_1_AssertGuard_false_16370, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_1_AssertGuard_true_16360} : (tensor, tensor, tensor, tensor) -> tensor + %475 = "tf.Identity"(%474) {device = ""} : (tensor) -> tensor + %476 = "tf.Select"(%2, %311, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %477 = "tf.Pack"(%476) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %478 = "tf.ConcatV2"(%1, %477, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %479 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %480 = "tf.Equal"(%479, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %481 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, 
new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %482 = "tf.StridedSlice"(%478, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %483 = "tf.Equal"(%482, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %484 = "tf.If"(%483, %483, %482, %425) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_2_AssertGuard_false_16860, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_2_AssertGuard_true_16850} : (tensor, tensor, tensor, tensor) -> tensor + %485 = "tf.Identity"(%484) {device = ""} : (tensor) -> tensor + %486 = "tf.If"(%480, %480, %425, %481) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_Assert_3_AssertGuard_false_17220, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_Assert_3_AssertGuard_true_17210} : (tensor, tensor, tensor, tensor) -> tensor + %487 = "tf.Identity"(%486) {device = ""} : (tensor) -> tensor + %488 = "tf.If"(%351, %351, %13, %347) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21900, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21890} : (tensor, tensor, tensor, tensor) -> tensor + %489 = "tf.Identity"(%488) {device = ""} : (tensor) -> tensor + %490 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %491 = "tf.Select"(%490, %13, %347) {device = ""} : (tensor, tensor, tensor) -> tensor + %492 = "tf.Equal"(%491, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %493 = "tf.LogicalOr"(%492, %2) {device = ""} : (tensor, tensor) -> tensor + %494 = "tf.Equal"(%491, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %495 = "tf.LogicalOr"(%493, %494) {device = ""} : (tensor, tensor) -> tensor + %496 = "tf.Select"(%358, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %497 = "tf.Pack"(%496, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %498 = "tf.StridedSlice"(%497, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %499 = "tf.Cast"(%498) {Truncate = false, device = ""} : (tensor) -> tensor + %500 = "tf.Reshape"(%499, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %501 = "tf.Pack"(%7, %500) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %502 = "tf.Tile"(%359, %501) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %503 = "tf.Mul"(%500, %362) {device = ""} : (tensor, tensor) -> tensor + %504 = "tf.Pack"(%503) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %505 = "tf.ConcatV2"(%361, %504, %363, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %506 = "tf.Reshape"(%502, %505) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %507 = "tf.Shape"(%506) {device = ""} : (tensor) -> tensor<1xi64> + 
%508 = "tf.StridedSlice"(%507, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %509 = "tf.Pack"(%498) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %510 = "tf.StridedSlice"(%506, %509, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %511 = "tf.Sub"(%508, %498) {device = ""} : (tensor, tensor) -> tensor + %512 = "tf.Pack"(%511) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %513 = "tf.StridedSlice"(%506, %11, %512, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %514:2 = "tf.RaggedRange"(%513, %510, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %515 = "tf.Select"(%2, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %516 = "tf.Pack"(%515, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %517 = "tf.StridedSlice"(%516, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %518 = "tf.Cast"(%517) {Truncate = false, device = ""} : (tensor) -> tensor + %519 = "tf.Reshape"(%518, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %520 = "tf.Pack"(%7, %519) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %521 = "tf.Tile"(%4, %520) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %522 = "tf.Mul"(%519, %8) {device = ""} : (tensor, tensor) -> tensor + %523 = "tf.Pack"(%522) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %524 = "tf.ConcatV2"(%9, %523, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %525 = "tf.Reshape"(%521, %524) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %526 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<1xi64> + %527 = "tf.StridedSlice"(%526, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %528 = "tf.Pack"(%517) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %529 = "tf.StridedSlice"(%525, %528, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %530 = "tf.Sub"(%527, %517) {device = ""} : (tensor, tensor) -> tensor + %531 = "tf.Pack"(%530) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %532 = "tf.StridedSlice"(%525, %11, %531, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %533:2 = "tf.RaggedRange"(%532, %529, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %534 = "tf.StridedSlice"(%516, %15, %16, %16) {begin_mask = 1 : i64, device = 
"", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %535 = "tf.StridedSlice"(%516, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %536 = "tf.Mul"(%535, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %537 = "tf.Tile"(%536, %534) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %538 = "tf.Cumsum"(%537, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %539 = "tf.ConcatV2"(%11, %538, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %540 = "tf.StridedSlice"(%539, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %541 = "tf.ExpandDims"(%540, %7) {device = ""} : (tensor, tensor) -> tensor + %542 = "tf.Shape"(%540) {device = ""} : (tensor) -> tensor<1xi32> + %543 = "tf.StridedSlice"(%542, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %544 = "tf.Pack"(%543) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %545 = "tf.StridedSlice"(%539, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %546 = "tf.ExpandDims"(%545, %7) {device = ""} : (tensor, tensor) -> tensor + %547 = "tf.Shape"(%545) {device = ""} : (tensor) -> tensor<1xi32> + %548 = "tf.StridedSlice"(%547, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %549 = "tf.Pack"(%548) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %550 = "tf.Equal"(%347, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %551 = "tf.Select"(%550, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %552 = "tf.Cast"(%551) {Truncate = false, device = ""} : (tensor) -> tensor + %553 = "tf.Reshape"(%552, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %554 = "tf.Pack"(%7, %553) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %555 = "tf.Mul"(%553, %8) {device = ""} : (tensor, tensor) -> tensor + %556 = "tf.Pack"(%555) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %557 = "tf.ConcatV2"(%9, %556, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %558 = "tf.Pack"(%551) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %559 = "tf.Pack"(%10, %347) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %560 = "tf.ExpandDims"(%559, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %561 = "tf.Tile"(%560, %554) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %562 = "tf.Reshape"(%561, %557) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %563 = "tf.Shape"(%562) {device = ""} : (tensor) -> tensor<1xi64> + %564 
= "tf.StridedSlice"(%563, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %565 = "tf.Sub"(%564, %551) {device = ""} : (tensor, tensor) -> tensor + %566 = "tf.Pack"(%565) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %567 = "tf.StridedSlice"(%562, %11, %566, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %568 = "tf.StridedSlice"(%562, %558, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %569:2 = "tf.RaggedRange"(%567, %568, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %570 = "tf.GatherV2"(%354, %569#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %571 = "tf.Cast"(%570) {Truncate = false, device = ""} : (tensor) -> tensor + %572 = "tf.BroadcastTo"(%571, %544) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %573 = "tf.Max"(%572, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %574 = "tf.Maximum"(%14, %573) {device = ""} : (tensor, tensor) -> tensor + %575 = "tf.Range"(%14, %574, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %576 = "tf.Pack"(%7, %574) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %577 = "tf.Tile"(%541, %576) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %578 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> + %579 = "tf.StridedSlice"(%578, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %580 = "tf.Prod"(%579, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %581 = "tf.Pack"(%580) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %582 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> + %583 = "tf.StridedSlice"(%582, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %584 = "tf.Shape"(%577) {device = ""} : (tensor) -> tensor<2xi32> + %585 = "tf.StridedSlice"(%584, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %586 = "tf.ConcatV2"(%583, %581, %585, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %587 = "tf.Reshape"(%577, %586) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %588 = "tf.ExpandDims"(%572, %3) {device = ""} : (tensor, tensor) -> tensor + %589 = "tf.Less"(%575, %588) {device = ""} : (tensor, tensor) -> tensor + %590 = "tf.Reshape"(%589, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %591 = "tf.Where"(%590) {device = ""} : (tensor) -> tensor + %592 = "tf.Squeeze"(%591) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %593 = "tf.GatherV2"(%587, 
%592, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %594 = "tf.Cast"(%570) {Truncate = false, device = ""} : (tensor) -> tensor + %595 = "tf.BroadcastTo"(%594, %549) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %596 = "tf.Max"(%595, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %597 = "tf.Maximum"(%14, %596) {device = ""} : (tensor, tensor) -> tensor + %598 = "tf.Range"(%14, %597, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %599 = "tf.Pack"(%7, %597) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %600 = "tf.Tile"(%546, %599) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %601 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> + %602 = "tf.StridedSlice"(%601, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %603 = "tf.Prod"(%602, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %604 = "tf.Pack"(%603) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %605 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> + %606 = "tf.StridedSlice"(%605, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %607 = "tf.Shape"(%600) {device = ""} : (tensor) -> tensor<2xi32> + %608 = "tf.StridedSlice"(%607, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %609 = "tf.ConcatV2"(%606, %604, %608, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %610 = "tf.Reshape"(%600, %609) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %611 = "tf.ExpandDims"(%595, %3) {device = ""} : (tensor, tensor) -> tensor + %612 = "tf.Less"(%598, %611) {device = ""} : (tensor, tensor) -> tensor + %613 = "tf.Reshape"(%612, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %614 = "tf.Where"(%613) {device = ""} : (tensor) -> tensor + %615 = "tf.Squeeze"(%614) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %616 = "tf.GatherV2"(%610, %615, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %617:2 = "tf.RaggedRange"(%593, %616, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %618 = "tf.If"(%495, %495, %491, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22960} : (tensor, tensor, tensor, tensor) -> tensor + %619 = "tf.Identity"(%618) {device = ""} : (tensor) -> tensor + %620 = "tf.Select"(%2, %491, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %621 = "tf.Pack"(%620) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %622 = "tf.ConcatV2"(%1, %621, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %623 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : 
i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %624 = "tf.Equal"(%623, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %625 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %626 = "tf.StridedSlice"(%622, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %627 = "tf.Equal"(%626, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %628 = "tf.If"(%627, %627, %626, %570) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23460, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23450} : (tensor, tensor, tensor, tensor) -> tensor + %629 = "tf.Identity"(%628) {device = ""} : (tensor) -> tensor + %630 = "tf.If"(%624, %624, %570, %625) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23820, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810} : (tensor, tensor, tensor, tensor) -> tensor + %631 = "tf.Identity"(%79) {device = ""} : (tensor) -> tensor + %632 = "tf.Identity"(%630) {device = ""} : (tensor) -> tensor + %633 = "tf.Identity"(%307) {device = ""} : (tensor) -> tensor + %634 = "tf.Shape"(%36#2) {device = ""} : (tensor) -> tensor<1xi32> + %635 = "tf.StridedSlice"(%634, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %636 = "tf.Cast"(%635) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %637 = "tf.Identity"(%636) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %638 = "tf.Shape"(%36#3) {device = ""} : (tensor) -> tensor<1xi32> + %639 = "tf.StridedSlice"(%638, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %640 = "tf.Cast"(%639) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %641 = "tf.Identity"(%640) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %642 = "tf.GatherV2"(%36#3, %335, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %643 = "tf.Tile"(%642, %339) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %644 = "tf.Sub"(%643, %473) {device = ""} : (tensor, tensor) -> tensor + %645 = "tf.Shape"(%644) {device = ""} : (tensor) -> tensor<1xi32> + %646 = "tf.StridedSlice"(%645, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %647 = "tf.Cast"(%646) {Truncate = false, device = ""} 
: (tensor<0xi32>) -> tensor<0xi64> + %648 = "tf.Identity"(%647) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %649 = "tf.UnicodeEncode"(%36#0, %57) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor + %650 = "tf.Identity"(%649) {device = ""} : (tensor) -> tensor + return %650, %631 : tensor, tensor } +func @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedConcat/RaggedFromTensor/Const:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedConcat/RaggedNRows/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3200(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3960(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4330(%arg0: tensor, %arg1: tensor) -> tensor 
attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4320(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4660(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5040(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () 
-> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5400(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5390(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5760(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = 
"tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6470(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6460(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> 
: tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6800(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7170(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7540(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device 
= ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7530(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7880(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> 
: tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9740(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10590(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_Assert_AssertGuard_false_15300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = 
"tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_Assert_AssertGuard_true_15290(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_Assert_1_AssertGuard_false_16370(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_Assert_1_AssertGuard_true_16360(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_Assert_2_AssertGuard_false_16860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_Assert_2_AssertGuard_true_16850(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_Assert_3_AssertGuard_false_17220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : 
tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_Assert_3_AssertGuard_true_17210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21900(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21890(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22960(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor 
+ %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23450(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23820(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} + +// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor, tensor) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} { +// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor, tensor) +// CHECK: return %0#0, %0#1 : tensor, tensor + +func @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> + %4 = "tf.Const"() {value = dense<[2, -1]> : tensor<2xi32>} : () -> tensor<2xi32> + %5 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %6 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %7 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %8 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %9 = "tf.Const"() 
{value = dense<1> : tensor} : () -> tensor + %10 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %11 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %12 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %13 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> + %14 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %15 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %16 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %17 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %18 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %19 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<2xi64> + %20 = "tf.StridedSlice"(%19, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %21 = "tf.StridedSlice"(%19, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %22 = "tf.Mul"(%20, %21) {device = ""} : (tensor, tensor) -> tensor + %23 = "tf.Pack"(%22) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %24 = "tf.StridedSlice"(%19, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %25 = "tf.ConcatV2"(%23, %24, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %26 = "tf.Reshape"(%arg0, %25) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %27 = "tf.StringLength"(%26) {device = "", unit = "BYTE"} : (tensor) -> tensor + %28 = "tf.ExpandDims"(%27, %9) {device = ""} : (tensor, tensor) -> tensor + %29 = "tf.Cast"(%28) {Truncate = false, device = ""} : (tensor) -> tensor + %30 = "tf.Shape"(%29) {device = ""} : (tensor) -> tensor<2xi64> + %31 = "tf.StridedSlice"(%30, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %32 = "tf.StridedSlice"(%30, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %33 = "tf.Mul"(%31, %32) {device = ""} : (tensor, tensor) -> tensor + %34 = "tf.Pack"(%33) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %35 = "tf.StridedSlice"(%30, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %36 = "tf.ConcatV2"(%34, %35, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %37 = "tf.Reshape"(%29, %36) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %38 = "tf.StridedSlice"(%30, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %39 = "tf.AddV2"(%38, %15) 
{device = ""} : (tensor, tensor) -> tensor + %40 = "tf.Range"(%12, %39, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %41 = "tf.Mul"(%40, %15) {device = ""} : (tensor, tensor) -> tensor + %42 = "tf.Reshape"(%26, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %43:3 = "tf.UnicodeDecodeWithOffsets"(%42) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor) -> (tensor, tensor, tensor) + %44 = "tf.StridedSlice"(%43#0, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %45 = "tf.Shape"(%44) {device = ""} : (tensor) -> tensor<1xi32> + %46 = "tf.ConcatV2"(%45, %18, %16) {device = ""} : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> + %47 = "tf.Reshape"(%44, %46) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %48 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi64> + %49 = "tf.StridedSlice"(%48, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %50 = "tf.AddV2"(%49, %15) {device = ""} : (tensor, tensor) -> tensor + %51 = "tf.Range"(%12, %50, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %52 = "tf.Mul"(%51, %15) {device = ""} : (tensor, tensor) -> tensor + %53 = "tf.ExpandDims"(%52, %9) {device = ""} : (tensor, tensor) -> tensor + %54 = "tf.Shape"(%52) {device = ""} : (tensor) -> tensor<1xi32> + %55 = "tf.StridedSlice"(%54, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %56 = "tf.StridedSlice"(%54, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %57 = "tf.StridedSlice"(%54, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %58 = "tf.StridedSlice"(%52, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %59 = "tf.StridedSlice"(%52, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %60 = "tf.Sub"(%58, %59) {device = ""} : (tensor, tensor) -> tensor + %61 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi32> + %62 = "tf.Cast"(%61) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> + %63 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %64 = "tf.Equal"(%63, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %65 = 
"tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %66 = "tf.Equal"(%65, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %67 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %68 = "tf.Shape"(%47) {device = ""} : (tensor) -> tensor<2xi32> + %69 = "tf.Cast"(%68) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> + %70 = "tf.StridedSlice"(%69, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %71 = "tf.Equal"(%70, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %72 = "tf.StridedSlice"(%43#0, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %73 = "tf.AddV2"(%72, %15) {device = ""} : (tensor, tensor) -> tensor + %74 = "tf.StridedSlice"(%43#0, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %75 = "tf.Minimum"(%73, %74) {device = ""} : (tensor, tensor) -> tensor + %76:2 = "tf.RaggedRange"(%75, %74, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %77 = "tf.Shape"(%76#0) {device = ""} : (tensor) -> tensor<1xi64> + %78 = "tf.StridedSlice"(%77, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %79 = "tf.Sub"(%78, %15) {device = ""} : (tensor, tensor) -> tensor + %80 = "tf.Equal"(%38, %79) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %81 = "tf.All"(%80, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %82 = "tf.If"(%81, %81, %38, %79) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_99640, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_99630} : (tensor, tensor, tensor, tensor) -> tensor + %83 = "tf.Identity"(%82) {device = ""} : (tensor) -> tensor + %84 = "tf.StridedSlice"(%41, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %85 = "tf.Mul"(%79, %5) {device = ""} : (tensor, tensor) -> tensor + %86 = "tf.Range"(%12, %85, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %87 = "tf.Reshape"(%86, %4) {device = ""} : (tensor, tensor<2xi32>) -> tensor<2x?xi64> + %88 = "tf.Transpose"(%87, %8) {device = ""} : 
(tensor<2x?xi64>, tensor<2xi32>) -> tensor + %89 = "tf.Reshape"(%88, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %90 = "tf.StridedSlice"(%76#0, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %91 = "tf.AddV2"(%84, %90) {device = ""} : (tensor, tensor) -> tensor + %92 = "tf.ConcatV2"(%76#0, %91, %16) {device = ""} : (tensor, tensor, tensor) -> tensor + %93 = "tf.GatherV2"(%43#2, %76#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %94 = "tf.ConcatV2"(%93, %37, %16) {device = ""} : (tensor, tensor, tensor) -> tensor + %95:2 = "tf.RaggedGather"(%92, %94, %89) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %96 = "tf.StridedSlice"(%95#0, %17, %17, %7) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %97 = "tf.StridedSlice"(%96, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %98 = "tf.Shape"(%97) {device = ""} : (tensor) -> tensor<1xi32> + %99 = "tf.ConcatV2"(%98, %18, %16) {device = ""} : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> + %100 = "tf.Reshape"(%97, %99) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %101 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi64> + %102 = "tf.StridedSlice"(%101, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %103 = "tf.AddV2"(%102, %15) {device = ""} : (tensor, tensor) -> tensor + %104 = "tf.Range"(%12, %103, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %105 = "tf.Mul"(%104, %15) {device = ""} : (tensor, tensor) -> tensor + %106 = "tf.ExpandDims"(%105, %9) {device = ""} : (tensor, tensor) -> tensor + %107 = "tf.Shape"(%105) {device = ""} : (tensor) -> tensor<1xi32> + %108 = "tf.StridedSlice"(%107, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %109 = "tf.StridedSlice"(%107, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %110 = "tf.StridedSlice"(%107, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %111 = "tf.StridedSlice"(%105, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %112 = "tf.StridedSlice"(%105, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, 
new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %113 = "tf.Sub"(%111, %112) {device = ""} : (tensor, tensor) -> tensor + %114 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi32> + %115 = "tf.Cast"(%114) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> + %116 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %117 = "tf.Equal"(%116, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %118 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %119 = "tf.Equal"(%118, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %120 = "tf.StridedSlice"(%115, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %121 = "tf.Shape"(%100) {device = ""} : (tensor) -> tensor<2xi32> + %122 = "tf.Cast"(%121) {Truncate = false, device = ""} : (tensor<2xi32>) -> tensor<2xi64> + %123 = "tf.StridedSlice"(%122, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %124 = "tf.Equal"(%123, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %125:5 = "tf.WhitespaceTokenizeWithOffsets"(%43#1, %43#0) {Tsplits = i64, device = ""} : (tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor) + %126 = "tf.StridedSlice"(%125#1, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %127 = "tf.Equal"(%126, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %128 = "tf.All"(%127, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %129 = "tf.If"(%128, %128, %126, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_100400, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_100390} : (tensor, tensor, tensor, tensor) -> tensor + %130 = "tf.Identity"(%129) {device = ""} : (tensor) -> tensor + %131 = "tf.StridedSlice"(%125#1, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %132 = "tf.StridedSlice"(%125#1, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : 
i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %133 = "tf.Sub"(%131, %132) {device = ""} : (tensor, tensor) -> tensor + %134 = "tf.LessEqual"(%12, %133) {device = ""} : (tensor, tensor) -> tensor + %135 = "tf.All"(%134, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %136 = "tf.If"(%135, %135, %133) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_100760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_100750} : (tensor, tensor, tensor) -> tensor + %137 = "tf.Identity"(%136) {device = ""} : (tensor) -> tensor + %138 = "tf.Identity"(%125#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %139 = "tf.StridedSlice"(%138, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %140 = "tf.Shape"(%125#0) {device = ""} : (tensor) -> tensor<1xi64> + %141 = "tf.StridedSlice"(%140, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %142 = "tf.Equal"(%139, %141) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %143 = "tf.All"(%142, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %144 = "tf.If"(%143, %143, %139, %141) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101100, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101090} : (tensor, tensor, tensor, tensor) -> tensor + %145 = "tf.Identity"(%144) {device = ""} : (tensor) -> tensor + %146 = "tf.Identity"(%138) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %147 = "tf.Shape"(%146) {device = ""} : (tensor) -> tensor<1xi64> + %148 = "tf.StridedSlice"(%147, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %149 = "tf.Sub"(%148, %15) {device = ""} : (tensor, tensor) -> tensor + %150 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %151 = "tf.Equal"(%150, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %152 = "tf.All"(%151, %11) {device = "", 
keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %153 = "tf.If"(%152, %152, %150, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101460} : (tensor, tensor, tensor, tensor) -> tensor + %154 = "tf.Identity"(%153) {device = ""} : (tensor) -> tensor + %155 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %156 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %157 = "tf.Sub"(%155, %156) {device = ""} : (tensor, tensor) -> tensor + %158 = "tf.LessEqual"(%12, %157) {device = ""} : (tensor, tensor) -> tensor + %159 = "tf.All"(%158, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %160 = "tf.If"(%159, %159, %157) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_101830, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_101820} : (tensor, tensor, tensor) -> tensor + %161 = "tf.Identity"(%160) {device = ""} : (tensor) -> tensor + %162 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %163 = "tf.StridedSlice"(%162, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %164 = "tf.Equal"(%163, %149) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %165 = "tf.All"(%164, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %166 = "tf.If"(%165, %165, %163, %149) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_102190, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_102180} : (tensor, tensor, tensor, tensor) -> tensor + %167 = "tf.Identity"(%166) {device = ""} : (tensor) -> tensor + %168 = "tf.Identity"(%162) {_class = 
["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %169 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %170 = "tf.Equal"(%169, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %171 = "tf.All"(%170, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %172 = "tf.If"(%171, %171, %169, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_102540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_102530} : (tensor, tensor, tensor, tensor) -> tensor + %173 = "tf.Identity"(%172) {device = ""} : (tensor) -> tensor + %174 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %175 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %176 = "tf.Sub"(%174, %175) {device = ""} : (tensor, tensor) -> tensor + %177 = "tf.LessEqual"(%12, %176) {device = ""} : (tensor, tensor) -> tensor + %178 = "tf.All"(%177, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %179 = "tf.If"(%178, %178, %176) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_102900, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_102890} : (tensor, tensor, tensor) -> tensor + %180 = "tf.Identity"(%179) {device = ""} : (tensor) -> tensor + %181 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %182 = "tf.StridedSlice"(%181, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %183 = "tf.Shape"(%125#2) {device = ""} : (tensor) -> tensor<1xi64> + %184 = "tf.StridedSlice"(%183, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %185 = "tf.Equal"(%182, %184) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> 
tensor + %186 = "tf.All"(%185, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %187 = "tf.If"(%186, %186, %182, %184) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103240, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103230} : (tensor, tensor, tensor, tensor) -> tensor + %188 = "tf.Identity"(%187) {device = ""} : (tensor) -> tensor + %189 = "tf.Identity"(%181) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %190 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi64> + %191 = "tf.StridedSlice"(%190, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %192 = "tf.Sub"(%191, %15) {device = ""} : (tensor, tensor) -> tensor + %193 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %194 = "tf.LogicalOr"(%64, %193) {device = ""} : (tensor, tensor) -> tensor + %195 = "tf.Equal"(%192, %63) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %196 = "tf.LogicalOr"(%194, %195) {device = ""} : (tensor, tensor) -> tensor + %197 = "tf.StridedSlice"(%189, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %198 = "tf.StridedSlice"(%189, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %199 = "tf.Sub"(%197, %198) {device = ""} : (tensor, tensor) -> tensor + %200 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi64> + %201 = "tf.StridedSlice"(%200, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %202 = "tf.Sub"(%201, %15) {device = ""} : (tensor, tensor) -> tensor + %203 = "tf.Equal"(%202, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %204 = "tf.ExpandDims"(%189, %9) {device = ""} : (tensor, tensor) -> tensor + %205 = "tf.Shape"(%189) {device = ""} : (tensor) -> tensor<1xi32> + %206 = "tf.StridedSlice"(%205, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %207 = "tf.StridedSlice"(%205, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %208 = "tf.StridedSlice"(%205, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %209 = "tf.StridedSlice"(%125#4, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %210 = "tf.Equal"(%209, %12) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %211 = "tf.All"(%210, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %212 = "tf.If"(%211, %211, %209, %12) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103610, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103600} : (tensor, tensor, tensor, tensor) -> tensor + %213 = "tf.Identity"(%212) {device = ""} : (tensor) -> tensor + %214 = "tf.StridedSlice"(%125#4, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %215 = "tf.StridedSlice"(%125#4, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %216 = "tf.Sub"(%214, %215) {device = ""} : (tensor, tensor) -> tensor + %217 = "tf.LessEqual"(%12, %216) {device = ""} : (tensor, tensor) -> tensor + %218 = "tf.All"(%217, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %219 = "tf.If"(%218, %218, %216) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_103970, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_103960} : (tensor, tensor, tensor) -> tensor + %220 = "tf.Identity"(%219) {device = ""} : (tensor) -> tensor + %221 = "tf.Identity"(%125#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %222 = "tf.StridedSlice"(%221, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %223 = "tf.Shape"(%125#3) {device = ""} : (tensor) -> tensor<1xi64> + %224 = "tf.StridedSlice"(%223, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %225 = "tf.Equal"(%222, %224) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %226 = 
"tf.All"(%225, %11) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %227 = "tf.If"(%226, %226, %222, %224) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_104310, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_104300} : (tensor, tensor, tensor, tensor) -> tensor + %228 = "tf.Identity"(%227) {device = ""} : (tensor) -> tensor + %229 = "tf.Identity"(%221) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %230 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi64> + %231 = "tf.StridedSlice"(%230, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %232 = "tf.Sub"(%231, %15) {device = ""} : (tensor, tensor) -> tensor + %233 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %234 = "tf.LogicalOr"(%233, %1) {device = ""} : (tensor, tensor) -> tensor + %235 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %236 = "tf.LogicalOr"(%234, %235) {device = ""} : (tensor, tensor) -> tensor + %237 = "tf.StridedSlice"(%229, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %238 = "tf.StridedSlice"(%229, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %239 = "tf.Sub"(%237, %238) {device = ""} : (tensor, tensor) -> tensor + %240 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi64> + %241 = "tf.StridedSlice"(%240, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %242 = "tf.Sub"(%241, %15) {device = ""} : (tensor, tensor) -> tensor + %243 = "tf.Equal"(%242, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %244 = "tf.ExpandDims"(%229, %9) {device = ""} : (tensor, tensor) -> tensor + %245 = "tf.Shape"(%229) {device = ""} : (tensor) -> tensor<1xi32> + %246 = "tf.StridedSlice"(%245, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %247 = "tf.StridedSlice"(%245, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %248 = "tf.StridedSlice"(%245, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : 
i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %249 = "tf.StridedSlice"(%229, %6, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %250 = "tf.Range"(%12, %249, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %251 = "tf.StridedSlice"(%229, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %252 = "tf.StridedSlice"(%229, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %253 = "tf.Sub"(%251, %252) {device = ""} : (tensor, tensor) -> tensor + %254 = "tf.If"(%196, %196, %63, %192) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_105110, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_105100} : (tensor, tensor, tensor, tensor) -> tensor + %255 = "tf.Identity"(%254) {device = ""} : (tensor) -> tensor + %256 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %257 = "tf.Select"(%256, %63, %192) {device = ""} : (tensor, tensor, tensor) -> tensor + %258 = "tf.Equal"(%257, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %259 = "tf.LogicalOr"(%258, %66) {device = ""} : (tensor, tensor) -> tensor + %260 = "tf.Equal"(%65, %257) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %261 = "tf.LogicalOr"(%259, %260) {device = ""} : (tensor, tensor) -> tensor + %262 = "tf.Select"(%203, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %263 = "tf.Pack"(%262, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %264 = "tf.StridedSlice"(%263, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %265 = "tf.Cast"(%264) {Truncate = false, device = ""} : (tensor) -> tensor + %266 = "tf.Reshape"(%265, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %267 = "tf.Pack"(%9, %266) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %268 = "tf.Tile"(%204, %267) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %269 = "tf.Mul"(%266, %207) {device = ""} : (tensor, tensor) -> tensor + %270 = "tf.Pack"(%269) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %271 = "tf.ConcatV2"(%206, %270, %208, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %272 = "tf.Reshape"(%268, %271) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %273 = "tf.Shape"(%272) {device = ""} : (tensor) -> tensor<1xi64> + %274 = "tf.StridedSlice"(%273, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %275 = 
"tf.Pack"(%264) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %276 = "tf.StridedSlice"(%272, %275, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %277 = "tf.Sub"(%274, %264) {device = ""} : (tensor, tensor) -> tensor + %278 = "tf.Pack"(%277) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %279 = "tf.StridedSlice"(%272, %13, %278, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %280:2 = "tf.RaggedRange"(%279, %276, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %281 = "tf.Select"(%71, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %282 = "tf.Pack"(%281, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %283 = "tf.StridedSlice"(%282, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %284 = "tf.Cast"(%283) {Truncate = false, device = ""} : (tensor) -> tensor + %285 = "tf.Reshape"(%284, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %286 = "tf.Pack"(%9, %285) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %287 = "tf.Tile"(%53, %286) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %288 = "tf.Mul"(%285, %56) {device = ""} : (tensor, tensor) -> tensor + %289 = "tf.Pack"(%288) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %290 = "tf.ConcatV2"(%55, %289, %57, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %291 = "tf.Reshape"(%287, %290) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %292 = "tf.Shape"(%291) {device = ""} : (tensor) -> tensor<1xi64> + %293 = "tf.StridedSlice"(%292, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %294 = "tf.Pack"(%283) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %295 = "tf.StridedSlice"(%291, %294, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %296 = "tf.Sub"(%293, %283) {device = ""} : (tensor, tensor) -> tensor + %297 = "tf.Pack"(%296) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %298 = "tf.StridedSlice"(%291, %13, %297, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %299:2 = "tf.RaggedRange"(%298, %295, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %300 = "tf.StridedSlice"(%282, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %301 = "tf.StridedSlice"(%282, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 
: i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %302 = "tf.Mul"(%60, %301) {device = ""} : (tensor, tensor) -> tensor + %303 = "tf.Tile"(%302, %300) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %304 = "tf.Cumsum"(%303, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %305 = "tf.ConcatV2"(%13, %304, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %306 = "tf.StridedSlice"(%305, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %307 = "tf.ExpandDims"(%306, %9) {device = ""} : (tensor, tensor) -> tensor + %308 = "tf.Shape"(%306) {device = ""} : (tensor) -> tensor<1xi32> + %309 = "tf.StridedSlice"(%308, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %310 = "tf.Pack"(%309) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %311 = "tf.StridedSlice"(%305, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %312 = "tf.ExpandDims"(%311, %9) {device = ""} : (tensor, tensor) -> tensor + %313 = "tf.Shape"(%311) {device = ""} : (tensor) -> tensor<1xi32> + %314 = "tf.StridedSlice"(%313, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %315 = "tf.Pack"(%314) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %316 = "tf.Equal"(%192, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %317 = "tf.Select"(%316, %257, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %318 = "tf.Cast"(%317) {Truncate = false, device = ""} : (tensor) -> tensor + %319 = "tf.Reshape"(%318, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %320 = "tf.Pack"(%9, %319) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %321 = "tf.Mul"(%319, %10) {device = ""} : (tensor, tensor) -> tensor + %322 = "tf.Pack"(%321) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %323 = "tf.ConcatV2"(%11, %322, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %324 = "tf.Pack"(%317) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %325 = "tf.Pack"(%12, %192) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %326 = "tf.ExpandDims"(%325, %9) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %327 = "tf.Tile"(%326, %320) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %328 = "tf.Reshape"(%327, %323) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %329 = "tf.Shape"(%328) {device = ""} : (tensor) -> tensor<1xi64> + %330 = "tf.StridedSlice"(%329, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %331 = "tf.Sub"(%330, %317) {device = ""} : 
(tensor, tensor) -> tensor + %332 = "tf.Pack"(%331) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %333 = "tf.StridedSlice"(%328, %13, %332, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %334 = "tf.StridedSlice"(%328, %324, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %335:2 = "tf.RaggedRange"(%333, %334, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %336 = "tf.GatherV2"(%199, %335#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %337 = "tf.Cast"(%336) {Truncate = false, device = ""} : (tensor) -> tensor + %338 = "tf.BroadcastTo"(%337, %310) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %339 = "tf.Max"(%338, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %340 = "tf.Maximum"(%16, %339) {device = ""} : (tensor, tensor) -> tensor + %341 = "tf.Range"(%16, %340, %9) {device = ""} : (tensor, tensor, tensor) -> tensor + %342 = "tf.Pack"(%9, %340) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %343 = "tf.Tile"(%307, %342) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %344 = "tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> + %345 = "tf.StridedSlice"(%344, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %346 = "tf.Prod"(%345, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %347 = "tf.Pack"(%346) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %348 = "tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> + %349 = "tf.StridedSlice"(%348, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %350 = "tf.Shape"(%343) {device = ""} : (tensor) -> tensor<2xi32> + %351 = "tf.StridedSlice"(%350, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %352 = "tf.ConcatV2"(%349, %347, %351, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %353 = "tf.Reshape"(%343, %352) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %354 = "tf.ExpandDims"(%338, %2) {device = ""} : (tensor, tensor) -> tensor + %355 = "tf.Less"(%341, %354) {device = ""} : (tensor, tensor) -> tensor + %356 = "tf.Reshape"(%355, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %357 = "tf.Where"(%356) {device = ""} : (tensor) -> tensor + %358 = "tf.Squeeze"(%357) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %359 = "tf.GatherV2"(%353, %358, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %360 = "tf.Cast"(%336) {Truncate = false, device = ""} : (tensor) -> tensor + %361 = "tf.BroadcastTo"(%360, %315) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %362 = "tf.Max"(%361, %17) {device = "", 
keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %363 = "tf.Maximum"(%16, %362) {device = ""} : (tensor, tensor) -> tensor + %364 = "tf.Range"(%16, %363, %9) {device = ""} : (tensor, tensor, tensor) -> tensor + %365 = "tf.Pack"(%9, %363) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %366 = "tf.Tile"(%312, %365) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %367 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> + %368 = "tf.StridedSlice"(%367, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %369 = "tf.Prod"(%368, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %370 = "tf.Pack"(%369) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %371 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> + %372 = "tf.StridedSlice"(%371, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %373 = "tf.Shape"(%366) {device = ""} : (tensor) -> tensor<2xi32> + %374 = "tf.StridedSlice"(%373, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %375 = "tf.ConcatV2"(%372, %370, %374, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %376 = "tf.Reshape"(%366, %375) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %377 = "tf.ExpandDims"(%361, %2) {device = ""} : (tensor, tensor) -> tensor + %378 = "tf.Less"(%364, %377) {device = ""} : (tensor, tensor) -> tensor + %379 = "tf.Reshape"(%378, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %380 = "tf.Where"(%379) {device = ""} : (tensor) -> tensor + %381 = "tf.Squeeze"(%380) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %382 = "tf.GatherV2"(%376, %381, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %383:2 = "tf.RaggedRange"(%359, %382, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %384 = "tf.If"(%261, %261, %257, %67) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_106180, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_106170} : (tensor, tensor, tensor, tensor) -> tensor + %385 = "tf.Identity"(%384) {device = ""} : (tensor) -> tensor + %386 = "tf.StridedSlice"(%62, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %387 = "tf.Equal"(%386, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %388 = "tf.Select"(%387, %257, %386) {device = ""} : (tensor, tensor, tensor) -> tensor + %389 = "tf.Pack"(%388) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %390 = "tf.StridedSlice"(%62, %17, %17, %18) {begin_mask = 1 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %391 = "tf.StridedSlice"(%62, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %392 = "tf.ConcatV2"(%390, %389, %391, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %393 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %394 = "tf.Equal"(%393, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %395 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %396 = "tf.StridedSlice"(%392, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %397 = "tf.Equal"(%396, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %398 = "tf.If"(%397, %397, %396, %336) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_106670, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_106660} : (tensor, tensor, tensor, tensor) -> tensor + %399 = "tf.Identity"(%398) {device = ""} : (tensor) -> tensor + %400 = "tf.If"(%394, %394, %336, %395) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_107030, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_107020} : (tensor, tensor, tensor, tensor) -> tensor + %401 = "tf.If"(%236, %236, %15, %232) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_111870, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_111860} : (tensor, tensor, tensor, tensor) -> tensor + %402 = "tf.Identity"(%401) {device = ""} : (tensor) -> tensor + %403 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %404 = "tf.Select"(%403, %15, %232) {device = ""} : (tensor, tensor, tensor) -> tensor + %405 = "tf.Equal"(%404, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %406 = "tf.LogicalOr"(%405, %1) {device = ""} : (tensor, tensor) -> tensor + %407 = "tf.Equal"(%404, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %408 = "tf.LogicalOr"(%406, %407) {device = 
""} : (tensor, tensor) -> tensor + %409 = "tf.Select"(%243, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %410 = "tf.Pack"(%409, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %411 = "tf.StridedSlice"(%410, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %412 = "tf.Cast"(%411) {Truncate = false, device = ""} : (tensor) -> tensor + %413 = "tf.Reshape"(%412, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %414 = "tf.Pack"(%9, %413) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %415 = "tf.Tile"(%244, %414) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %416 = "tf.Mul"(%413, %247) {device = ""} : (tensor, tensor) -> tensor + %417 = "tf.Pack"(%416) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %418 = "tf.ConcatV2"(%246, %417, %248, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %419 = "tf.Reshape"(%415, %418) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %420 = "tf.Shape"(%419) {device = ""} : (tensor) -> tensor<1xi64> + %421 = "tf.StridedSlice"(%420, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %422 = "tf.Pack"(%411) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %423 = "tf.StridedSlice"(%419, %422, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %424 = "tf.Sub"(%421, %411) {device = ""} : (tensor, tensor) -> tensor + %425 = "tf.Pack"(%424) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %426 = "tf.StridedSlice"(%419, %13, %425, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %427:2 = "tf.RaggedRange"(%426, %423, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %428 = "tf.GatherV2"(%250, %427#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %429 = "tf.StridedSlice"(%410, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %430 = "tf.StridedSlice"(%410, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %431 = "tf.StridedSlice"(%410, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %432 = "tf.ConcatV2"(%430, %431, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %433 = "tf.StridedSlice"(%410, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask 
= 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %434 = "tf.Mul"(%253, %433) {device = ""} : (tensor, tensor) -> tensor + %435 = "tf.Tile"(%434, %429) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %436 = "tf.Cumsum"(%435, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %437 = "tf.ConcatV2"(%13, %436, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %438 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi64> + %439 = "tf.StridedSlice"(%438, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %440 = "tf.Sub"(%439, %15) {device = ""} : (tensor, tensor) -> tensor + %441 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %442 = "tf.LogicalOr"(%117, %441) {device = ""} : (tensor, tensor) -> tensor + %443 = "tf.Equal"(%440, %116) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %444 = "tf.LogicalOr"(%442, %443) {device = ""} : (tensor, tensor) -> tensor + %445 = "tf.StridedSlice"(%437, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %446 = "tf.StridedSlice"(%437, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %447 = "tf.Sub"(%445, %446) {device = ""} : (tensor, tensor) -> tensor + %448 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi64> + %449 = "tf.StridedSlice"(%448, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %450 = "tf.Sub"(%449, %15) {device = ""} : (tensor, tensor) -> tensor + %451 = "tf.Equal"(%450, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %452 = "tf.ExpandDims"(%437, %9) {device = ""} : (tensor, tensor) -> tensor + %453 = "tf.Shape"(%437) {device = ""} : (tensor) -> tensor<1xi32> + %454 = "tf.StridedSlice"(%453, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %455 = "tf.StridedSlice"(%453, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %456 = "tf.StridedSlice"(%453, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %457 = "tf.Select"(%1, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %458 = "tf.Pack"(%457, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %459 = "tf.StridedSlice"(%458, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %460 = "tf.Cast"(%459) {Truncate = false, device = ""} : (tensor) -> tensor + %461 = "tf.Reshape"(%460, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %462 = "tf.Pack"(%9, %461) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %463 = "tf.Tile"(%3, %462) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %464 = "tf.Mul"(%461, %10) {device = ""} : (tensor, tensor) -> tensor + %465 = "tf.Pack"(%464) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %466 = "tf.ConcatV2"(%11, %465, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %467 = "tf.Reshape"(%463, %466) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %468 = "tf.Shape"(%467) {device = ""} : (tensor) -> tensor<1xi64> + %469 = "tf.StridedSlice"(%468, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %470 = "tf.Pack"(%459) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %471 = "tf.StridedSlice"(%467, %470, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %472 = "tf.Sub"(%469, %459) {device = ""} : (tensor, tensor) -> tensor + %473 = "tf.Pack"(%472) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %474 = "tf.StridedSlice"(%467, %13, %473, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %475:2 = "tf.RaggedRange"(%474, %471, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %476 = "tf.GatherV2"(%13, %475#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %477 = "tf.GatherV2"(%14, %476, %16) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %478 = "tf.StridedSlice"(%458, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %479 = "tf.StridedSlice"(%458, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %480 = "tf.StridedSlice"(%458, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %481 = "tf.ConcatV2"(%479, %480, %16) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %482 = "tf.Tile"(%477, %481) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %483 = "tf.StridedSlice"(%458, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %484 = "tf.Mul"(%483, %14) {device = ""} : (tensor, 
tensor<1xi64>) -> tensor<1xi64> + %485 = "tf.Tile"(%484, %478) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %486 = "tf.Cumsum"(%485, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %487 = "tf.ConcatV2"(%13, %486, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %488 = "tf.StridedSlice"(%487, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %489 = "tf.ExpandDims"(%488, %9) {device = ""} : (tensor, tensor) -> tensor + %490 = "tf.Shape"(%488) {device = ""} : (tensor) -> tensor<1xi32> + %491 = "tf.StridedSlice"(%490, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %492 = "tf.Pack"(%491) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %493 = "tf.StridedSlice"(%487, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %494 = "tf.ExpandDims"(%493, %9) {device = ""} : (tensor, tensor) -> tensor + %495 = "tf.Shape"(%493) {device = ""} : (tensor) -> tensor<1xi32> + %496 = "tf.StridedSlice"(%495, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %497 = "tf.Pack"(%496) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %498 = "tf.Equal"(%232, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %499 = "tf.Select"(%498, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %500 = "tf.Cast"(%499) {Truncate = false, device = ""} : (tensor) -> tensor + %501 = "tf.Reshape"(%500, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %502 = "tf.Pack"(%9, %501) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %503 = "tf.Mul"(%501, %10) {device = ""} : (tensor, tensor) -> tensor + %504 = "tf.Pack"(%503) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %505 = "tf.ConcatV2"(%11, %504, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %506 = "tf.Pack"(%499) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %507 = "tf.Pack"(%12, %232) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %508 = "tf.ExpandDims"(%507, %9) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %509 = "tf.Tile"(%508, %502) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %510 = "tf.Reshape"(%509, %505) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %511 = "tf.Shape"(%510) {device = ""} : (tensor) -> tensor<1xi64> + %512 = "tf.StridedSlice"(%511, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %513 = "tf.Sub"(%512, %499) {device = ""} : (tensor, tensor) -> tensor + %514 = "tf.Pack"(%513) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %515 = "tf.StridedSlice"(%510, %13, %514, %14) {begin_mask = 1 : i64, device = 
"", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %516 = "tf.StridedSlice"(%510, %506, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %517:2 = "tf.RaggedRange"(%515, %516, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %518 = "tf.GatherV2"(%239, %517#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %519 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor + %520 = "tf.BroadcastTo"(%519, %492) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %521 = "tf.Max"(%520, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %522 = "tf.Maximum"(%16, %521) {device = ""} : (tensor, tensor) -> tensor + %523 = "tf.Range"(%16, %522, %9) {device = ""} : (tensor, tensor, tensor) -> tensor + %524 = "tf.Pack"(%9, %522) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %525 = "tf.Tile"(%489, %524) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %526 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> + %527 = "tf.StridedSlice"(%526, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %528 = "tf.Prod"(%527, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %529 = "tf.Pack"(%528) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %530 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> + %531 = "tf.StridedSlice"(%530, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %532 = "tf.Shape"(%525) {device = ""} : (tensor) -> tensor<2xi32> + %533 = "tf.StridedSlice"(%532, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %534 = "tf.ConcatV2"(%531, %529, %533, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %535 = "tf.Reshape"(%525, %534) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %536 = "tf.ExpandDims"(%520, %2) {device = ""} : (tensor, tensor) -> tensor + %537 = "tf.Less"(%523, %536) {device = ""} : (tensor, tensor) -> tensor + %538 = "tf.Reshape"(%537, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %539 = "tf.Where"(%538) {device = ""} : (tensor) -> tensor + %540 = "tf.Squeeze"(%539) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %541 = "tf.GatherV2"(%535, %540, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %542 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor + %543 = "tf.BroadcastTo"(%542, %497) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %544 = "tf.Max"(%543, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %545 = "tf.Maximum"(%16, %544) {device = ""} : (tensor, tensor) -> tensor + %546 = "tf.Range"(%16, %545, %9) {device = ""} : (tensor, 
tensor, tensor) -> tensor + %547 = "tf.Pack"(%9, %545) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %548 = "tf.Tile"(%494, %547) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %549 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> + %550 = "tf.StridedSlice"(%549, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %551 = "tf.Prod"(%550, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %552 = "tf.Pack"(%551) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %553 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> + %554 = "tf.StridedSlice"(%553, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %555 = "tf.Shape"(%548) {device = ""} : (tensor) -> tensor<2xi32> + %556 = "tf.StridedSlice"(%555, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %557 = "tf.ConcatV2"(%554, %552, %556, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %558 = "tf.Reshape"(%548, %557) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %559 = "tf.ExpandDims"(%543, %2) {device = ""} : (tensor, tensor) -> tensor + %560 = "tf.Less"(%546, %559) {device = ""} : (tensor, tensor) -> tensor + %561 = "tf.Reshape"(%560, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %562 = "tf.Where"(%561) {device = ""} : (tensor) -> tensor + %563 = "tf.Squeeze"(%562) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %564 = "tf.GatherV2"(%558, %563, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %565:2 = "tf.RaggedRange"(%541, %564, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %566 = "tf.GatherV2"(%482, %565#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %567 = "tf.If"(%408, %408, %404, %15) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_112940, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_112930} : (tensor, tensor, tensor, tensor) -> tensor + %568 = "tf.Identity"(%567) {device = ""} : (tensor) -> tensor + %569 = "tf.Select"(%1, %404, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %570 = "tf.Pack"(%569) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %571 = "tf.ConcatV2"(%0, %570, %14, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %572 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %573 = "tf.Equal"(%572, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %574 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : 
i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %575 = "tf.StridedSlice"(%571, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %576 = "tf.Equal"(%575, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %577 = "tf.If"(%576, %576, %575, %518) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_113430, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_113420} : (tensor, tensor, tensor, tensor) -> tensor + %578 = "tf.Identity"(%577) {device = ""} : (tensor) -> tensor + %579 = "tf.If"(%573, %573, %518, %574) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_113790, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_113780} : (tensor, tensor, tensor, tensor) -> tensor + %580 = "tf.Identity"(%579) {device = ""} : (tensor) -> tensor + %581 = "tf.If"(%444, %444, %116, %440) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_118470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_118460} : (tensor, tensor, tensor, tensor) -> tensor + %582 = "tf.Identity"(%581) {device = ""} : (tensor) -> tensor + %583 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %584 = "tf.Select"(%583, %116, %440) {device = ""} : (tensor, tensor, tensor) -> tensor + %585 = "tf.Equal"(%584, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %586 = "tf.LogicalOr"(%585, %119) {device = ""} : (tensor, tensor) -> tensor + %587 = "tf.Equal"(%118, %584) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %588 = "tf.LogicalOr"(%586, %587) {device = ""} : (tensor, tensor) -> tensor + %589 = "tf.Select"(%451, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %590 = "tf.Pack"(%589, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %591 = "tf.StridedSlice"(%590, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %592 = "tf.Cast"(%591) {Truncate = false, device = ""} : (tensor) -> tensor + %593 = "tf.Reshape"(%592, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %594 = "tf.Pack"(%9, %593) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %595 = "tf.Tile"(%452, %594) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %596 = "tf.Mul"(%593, %455) {device = ""} : (tensor, tensor) -> tensor + %597 = "tf.Pack"(%596) {axis = 0 : i64, device = ""} : (tensor) -> 
tensor<1xi32> + %598 = "tf.ConcatV2"(%454, %597, %456, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %599 = "tf.Reshape"(%595, %598) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %600 = "tf.Shape"(%599) {device = ""} : (tensor) -> tensor<1xi64> + %601 = "tf.StridedSlice"(%600, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %602 = "tf.Pack"(%591) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %603 = "tf.StridedSlice"(%599, %602, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %604 = "tf.Sub"(%601, %591) {device = ""} : (tensor, tensor) -> tensor + %605 = "tf.Pack"(%604) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %606 = "tf.StridedSlice"(%599, %13, %605, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %607:2 = "tf.RaggedRange"(%606, %603, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %608 = "tf.Select"(%124, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %609 = "tf.Pack"(%608, %15) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %610 = "tf.StridedSlice"(%609, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %611 = "tf.Cast"(%610) {Truncate = false, device = ""} : (tensor) -> tensor + %612 = "tf.Reshape"(%611, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %613 = "tf.Pack"(%9, %612) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %614 = "tf.Tile"(%106, %613) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %615 = "tf.Mul"(%612, %109) {device = ""} : (tensor, tensor) -> tensor + %616 = "tf.Pack"(%615) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %617 = "tf.ConcatV2"(%108, %616, %110, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %618 = "tf.Reshape"(%614, %617) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %619 = "tf.Shape"(%618) {device = ""} : (tensor) -> tensor<1xi64> + %620 = "tf.StridedSlice"(%619, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %621 = "tf.Pack"(%610) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %622 = "tf.StridedSlice"(%618, %621, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %623 = "tf.Sub"(%620, %610) {device = ""} : (tensor, tensor) -> tensor + %624 = "tf.Pack"(%623) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %625 = "tf.StridedSlice"(%618, %13, %624, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 
0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %626:2 = "tf.RaggedRange"(%625, %622, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %627 = "tf.StridedSlice"(%609, %17, %18, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %628 = "tf.StridedSlice"(%609, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %629 = "tf.Mul"(%113, %628) {device = ""} : (tensor, tensor) -> tensor + %630 = "tf.Tile"(%629, %627) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %631 = "tf.Cumsum"(%630, %16) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %632 = "tf.ConcatV2"(%13, %631, %2) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %633 = "tf.StridedSlice"(%632, %17, %6, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %634 = "tf.ExpandDims"(%633, %9) {device = ""} : (tensor, tensor) -> tensor + %635 = "tf.Shape"(%633) {device = ""} : (tensor) -> tensor<1xi32> + %636 = "tf.StridedSlice"(%635, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %637 = "tf.Pack"(%636) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %638 = "tf.StridedSlice"(%632, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %639 = "tf.ExpandDims"(%638, %9) {device = ""} : (tensor, tensor) -> tensor + %640 = "tf.Shape"(%638) {device = ""} : (tensor) -> tensor<1xi32> + %641 = "tf.StridedSlice"(%640, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %642 = "tf.Pack"(%641) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %643 = "tf.Equal"(%440, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %644 = "tf.Select"(%643, %584, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %645 = "tf.Cast"(%644) {Truncate = false, device = ""} : (tensor) -> tensor + %646 = "tf.Reshape"(%645, %11) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %647 = "tf.Pack"(%9, %646) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %648 = "tf.Mul"(%646, %10) {device = ""} : (tensor, tensor) -> tensor + %649 = "tf.Pack"(%648) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %650 = "tf.ConcatV2"(%11, %649, %11, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %651 = "tf.Pack"(%644) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %652 = "tf.Pack"(%12, %440) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %653 = "tf.ExpandDims"(%652, %9) {device = ""} : (tensor<2xi64>, tensor) -> 
tensor<2x1xi64> + %654 = "tf.Tile"(%653, %647) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %655 = "tf.Reshape"(%654, %650) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %656 = "tf.Shape"(%655) {device = ""} : (tensor) -> tensor<1xi64> + %657 = "tf.StridedSlice"(%656, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %658 = "tf.Sub"(%657, %644) {device = ""} : (tensor, tensor) -> tensor + %659 = "tf.Pack"(%658) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %660 = "tf.StridedSlice"(%655, %13, %659, %14) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %661 = "tf.StridedSlice"(%655, %651, %13, %14) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %662:2 = "tf.RaggedRange"(%660, %661, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %663 = "tf.GatherV2"(%447, %662#1, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %664 = "tf.Cast"(%663) {Truncate = false, device = ""} : (tensor) -> tensor + %665 = "tf.BroadcastTo"(%664, %637) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %666 = "tf.Max"(%665, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %667 = "tf.Maximum"(%16, %666) {device = ""} : (tensor, tensor) -> tensor + %668 = "tf.Range"(%16, %667, %9) {device = ""} : (tensor, tensor, tensor) -> tensor + %669 = "tf.Pack"(%9, %667) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %670 = "tf.Tile"(%634, %669) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %671 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> + %672 = "tf.StridedSlice"(%671, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %673 = "tf.Prod"(%672, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %674 = "tf.Pack"(%673) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %675 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> + %676 = "tf.StridedSlice"(%675, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %677 = "tf.Shape"(%670) {device = ""} : (tensor) -> tensor<2xi32> + %678 = "tf.StridedSlice"(%677, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %679 = "tf.ConcatV2"(%676, %674, %678, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %680 = "tf.Reshape"(%670, %679) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %681 = "tf.ExpandDims"(%665, %2) {device = ""} : (tensor, tensor) -> tensor + %682 = "tf.Less"(%668, %681) {device = ""} : 
(tensor, tensor) -> tensor + %683 = "tf.Reshape"(%682, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %684 = "tf.Where"(%683) {device = ""} : (tensor) -> tensor + %685 = "tf.Squeeze"(%684) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %686 = "tf.GatherV2"(%680, %685, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %687 = "tf.Cast"(%663) {Truncate = false, device = ""} : (tensor) -> tensor + %688 = "tf.BroadcastTo"(%687, %642) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %689 = "tf.Max"(%688, %17) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %690 = "tf.Maximum"(%16, %689) {device = ""} : (tensor, tensor) -> tensor + %691 = "tf.Range"(%16, %690, %9) {device = ""} : (tensor, tensor, tensor) -> tensor + %692 = "tf.Pack"(%9, %690) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %693 = "tf.Tile"(%639, %692) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %694 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> + %695 = "tf.StridedSlice"(%694, %17, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %696 = "tf.Prod"(%695, %17) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %697 = "tf.Pack"(%696) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %698 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> + %699 = "tf.StridedSlice"(%698, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %700 = "tf.Shape"(%693) {device = ""} : (tensor) -> tensor<2xi32> + %701 = "tf.StridedSlice"(%700, %7, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %702 = "tf.ConcatV2"(%699, %697, %701, %16) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %703 = "tf.Reshape"(%693, %702) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %704 = "tf.ExpandDims"(%688, %2) {device = ""} : (tensor, tensor) -> tensor + %705 = "tf.Less"(%691, %704) {device = ""} : (tensor, tensor) -> tensor + %706 = "tf.Reshape"(%705, %6) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %707 = "tf.Where"(%706) {device = ""} : (tensor) -> tensor + %708 = "tf.Squeeze"(%707) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %709 = "tf.GatherV2"(%703, %708, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %710:2 = "tf.RaggedRange"(%686, %709, %15) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %711 = "tf.If"(%588, %588, %584, %120) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_119540, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_119530} : (tensor, tensor, tensor, tensor) -> tensor + %712 = "tf.Identity"(%711) {device = ""} : (tensor) -> tensor + %713 = "tf.StridedSlice"(%115, %17, %18, 
%18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %714 = "tf.Equal"(%713, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %715 = "tf.Select"(%714, %584, %713) {device = ""} : (tensor, tensor, tensor) -> tensor + %716 = "tf.Pack"(%715) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %717 = "tf.StridedSlice"(%115, %17, %17, %18) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %718 = "tf.StridedSlice"(%115, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %719 = "tf.ConcatV2"(%717, %716, %718, %16) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %720 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %721 = "tf.Equal"(%720, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %722 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %723 = "tf.StridedSlice"(%719, %18, %7, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %724 = "tf.Equal"(%723, %15) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %725 = "tf.If"(%724, %724, %723, %663) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_120030, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_120020} : (tensor, tensor, tensor, tensor) -> tensor + %726 = "tf.Identity"(%725) {device = ""} : (tensor) -> tensor + %727 = "tf.If"(%721, %721, %663, %722) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_120390, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_120380} : (tensor, tensor, tensor, tensor) -> tensor + %728 = "tf.Identity"(%168) {device = ""} : (tensor) -> tensor + %729 = "tf.Identity"(%727) {device = ""} : (tensor) -> tensor + %730 = "tf.Identity"(%400) {device = ""} : (tensor) -> tensor + %731 = "tf.Shape"(%125#2) {device = ""} : (tensor) -> tensor<1xi32> + %732 = "tf.StridedSlice"(%731, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : 
i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %733 = "tf.Cast"(%732) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %734 = "tf.Identity"(%733) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %735 = "tf.Shape"(%125#3) {device = ""} : (tensor) -> tensor<1xi32> + %736 = "tf.StridedSlice"(%735, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %737 = "tf.Cast"(%736) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %738 = "tf.Identity"(%737) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %739 = "tf.GatherV2"(%125#3, %428, %16) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %740 = "tf.Tile"(%739, %432) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %741 = "tf.Sub"(%740, %566) {device = ""} : (tensor, tensor) -> tensor + %742 = "tf.Shape"(%741) {device = ""} : (tensor) -> tensor<1xi32> + %743 = "tf.StridedSlice"(%742, %18, %17, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %744 = "tf.Cast"(%743) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %745 = "tf.Identity"(%744) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %746 = "tf.UnicodeEncode"(%125#0, %146) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor + %747 = "tf.Identity"(%746) {device = ""} : (tensor) -> tensor + %748 = "tf.StridedSlice"(%19, %17, %18, %18) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %749 = "tf.AddV2"(%748, %15) {device = ""} : (tensor, tensor) -> tensor + %750 = "tf.Range"(%12, %749, %15) {device = ""} : (tensor, tensor, tensor) -> tensor + %751 = "tf.Mul"(%750, %15) {device = ""} : (tensor, tensor) -> tensor + %752 = "tf.Identity"(%751) {device = ""} : (tensor) -> tensor + return %747, %752, %728 : tensor, tensor, tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_99640(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedFromTensor/strided_slice_4:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedNRows/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_99630(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
+  %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor
+  return %1 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_100400(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
+  %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor
+  %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor
+  %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor
+  %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor
+  "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> ()
+  %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor
+  return %5 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_100390(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
+  %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor
+  return %1 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_100760(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} {
+  %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor
+  %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor
+  %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor
+  "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> ()
+  %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor
+  return %4 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_100750(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} {
+  %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor
+  return %1 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
+  %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor
+  %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor
+  %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor
+  %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor
+  "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> ()
+  %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor
+  return %5 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101090(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
+  %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor
+  return %1 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_101470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
+  %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor
+  %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor
+  %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor
+  %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor
+  "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> ()
+  %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor
+  return %5 : tensor
+}
+func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_101460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
+  %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor
+  %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor
+  return %1 : tensor
+}
+func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_101830(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_101820(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_102190(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_102180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_102540(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to 
from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_102530(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_102900(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_102890(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x 
(WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_103610(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_103600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_103970(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x 
(WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_103960(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_104310(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_104300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_105110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_105100(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_106180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_106170(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_106670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_106660(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_107030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, 
%1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_107020(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_111870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_111860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_112940(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_112930(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_113430(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() 
{value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_113420(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_113790(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_113780(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_118470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_118460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_119540(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> 
tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_119530(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_120030(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_120020(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_120390(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_120380(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = 
"tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} + + + +// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor {tf._user_specified_name = "input"}) -> (tensor, tensor, tensor) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape], tf.signature.is_stateful} { +// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor) -> (tensor, tensor, tensor) +// CHECK: return %0#0, %0#1, %0#2 : tensor, tensor, tensor + +func @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64> + %2 = "tf.Const"() {value = dense : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %4 = "tf.Const"() {value = dense<[[0], [1]]> : tensor<2x1xi64>} : () -> tensor<2x1xi64> + %5 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %6 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %7 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %8 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + %9 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> + %10 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %11 = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> tensor<1xi64> + %12 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %13 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %14 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %15 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %16 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %17 = "tf.If"(%2, %2, %13, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3220, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3210} : (tensor, tensor, tensor, tensor) -> tensor + %18 = "tf.Identity"(%17) {device = ""} : (tensor) -> tensor + %19 = "tf.Pack"(%arg0) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1x!tf.string> + %20 = "tf.StringLength"(%19) {device = "", unit = "BYTE"} : (tensor<1x!tf.string>) -> tensor<1xi32> + %21 = "tf.ExpandDims"(%20, %7) {device = ""} : (tensor<1xi32>, tensor) -> tensor<1x1xi32> + %22 = "tf.Cast"(%21) {Truncate = false, device = ""} : (tensor<1x1xi32>) -> tensor<1x1xi64> + %23 = "tf.Reshape"(%22, %12) {device = ""} : (tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> + %24 = "tf.Reshape"(%19, %5) {device = ""} : (tensor<1x!tf.string>, tensor<1xi32>) -> tensor<1x!tf.string> + %25:3 = "tf.UnicodeDecodeWithOffsets"(%24) {Tsplits = i64, device = "", errors = "replace", input_encoding = "UTF-8", replace_control_characters = false, replacement_char = 65533 : i64} : (tensor<1x!tf.string>) -> (tensor<2xi64>, tensor, tensor) + %26 = "tf.StridedSlice"(%25#0, %15, %5, %16) {begin_mask = 1 
: i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %27 = "tf.AddV2"(%26, %13) {device = ""} : (tensor<1xi64>, tensor) -> tensor<1xi64> + %28 = "tf.StridedSlice"(%25#0, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %29 = "tf.Minimum"(%27, %28) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + %30:2 = "tf.RaggedRange"(%29, %28, %13) {T = i64, Tsplits = i64, device = ""} : (tensor<1xi64>, tensor<1xi64>, tensor) -> (tensor<2xi64>, tensor) + %31 = "tf.StridedSlice"(%30#0, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %32 = "tf.AddV2"(%31, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %33 = "tf.ConcatV2"(%30#0, %32, %14) {device = ""} : (tensor<2xi64>, tensor<1xi64>, tensor) -> tensor<3xi64> + %34 = "tf.GatherV2"(%25#2, %30#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %35 = "tf.ConcatV2"(%34, %23, %14) {device = ""} : (tensor, tensor<1xi64>, tensor) -> tensor + %36:2 = "tf.RaggedGather"(%33, %35, %0) {OUTPUT_RAGGED_RANK = 1 : i64, PARAMS_RAGGED_RANK = 1 : i64, Tindices = i64, Tsplits = i64, Tvalues = i64, device = ""} : (tensor<3xi64>, tensor, tensor<2xi64>) -> (tensor, tensor) + %37:5 = "tf.WhitespaceTokenizeWithOffsets"(%25#1, %25#0) {Tsplits = i64, device = ""} : (tensor, tensor<2xi64>) -> (tensor, tensor, tensor, tensor, tensor) + %38 = "tf.StridedSlice"(%37#1, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %39 = "tf.Equal"(%38, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %40 = "tf.All"(%39, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %41 = "tf.If"(%40, %40, %38, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3980, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3970} : (tensor, tensor, tensor, tensor) -> tensor + %42 = "tf.Identity"(%41) {device = ""} : (tensor) -> tensor + %43 = "tf.StridedSlice"(%37#1, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %44 = "tf.StridedSlice"(%37#1, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %45 = "tf.Sub"(%43, %44) {device = ""} : (tensor, tensor) -> tensor + %46 = "tf.LessEqual"(%10, %45) {device = ""} : (tensor, tensor) -> tensor + %47 = 
"tf.All"(%46, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %48 = "tf.If"(%47, %47, %45) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4340, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4330} : (tensor, tensor, tensor) -> tensor + %49 = "tf.Identity"(%48) {device = ""} : (tensor) -> tensor + %50 = "tf.Identity"(%37#1) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %51 = "tf.StridedSlice"(%50, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %52 = "tf.Shape"(%37#0) {device = ""} : (tensor) -> tensor<1xi64> + %53 = "tf.StridedSlice"(%52, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %54 = "tf.Equal"(%51, %53) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %55 = "tf.All"(%54, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %56 = "tf.If"(%55, %55, %51, %53) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4680, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4670} : (tensor, tensor, tensor, tensor) -> tensor + %57 = "tf.Identity"(%56) {device = ""} : (tensor) -> tensor + %58 = "tf.Identity"(%50) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %59 = "tf.Shape"(%58) {device = ""} : (tensor) -> tensor<1xi64> + %60 = "tf.StridedSlice"(%59, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %61 = "tf.Sub"(%60, %13) {device = ""} : (tensor, tensor) -> tensor + %62 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %63 = "tf.Equal"(%62, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %64 = "tf.All"(%63, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %65 = "tf.If"(%64, %64, %62, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5050, is_stateless = false, output_shapes = [#tf.shape<>], 
then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5040} : (tensor, tensor, tensor, tensor) -> tensor + %66 = "tf.Identity"(%65) {device = ""} : (tensor) -> tensor + %67 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %68 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %69 = "tf.Sub"(%67, %68) {device = ""} : (tensor, tensor) -> tensor + %70 = "tf.LessEqual"(%10, %69) {device = ""} : (tensor, tensor) -> tensor + %71 = "tf.All"(%70, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %72 = "tf.If"(%71, %71, %69) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5410, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5400} : (tensor, tensor, tensor) -> tensor + %73 = "tf.Identity"(%72) {device = ""} : (tensor) -> tensor + %74 = "tf.Identity"(%37#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %75 = "tf.StridedSlice"(%74, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %76 = "tf.Equal"(%75, %61) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %77 = "tf.All"(%76, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %78 = "tf.If"(%77, %77, %75, %61) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5770, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5760} : (tensor, tensor, tensor, tensor) -> tensor + %79 = "tf.Identity"(%78) {device = ""} : (tensor) -> tensor + %80 = "tf.Identity"(%74) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %81 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %82 = "tf.Equal"(%81, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %83 = "tf.All"(%82, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %84 = "tf.If"(%83, %83, %81, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", 
else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6120, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6110} : (tensor, tensor, tensor, tensor) -> tensor + %85 = "tf.Identity"(%84) {device = ""} : (tensor) -> tensor + %86 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %87 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %88 = "tf.Sub"(%86, %87) {device = ""} : (tensor, tensor) -> tensor + %89 = "tf.LessEqual"(%10, %88) {device = ""} : (tensor, tensor) -> tensor + %90 = "tf.All"(%89, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %91 = "tf.If"(%90, %90, %88) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6480, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6470} : (tensor, tensor, tensor) -> tensor + %92 = "tf.Identity"(%91) {device = ""} : (tensor) -> tensor + %93 = "tf.Identity"(%37#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %94 = "tf.StridedSlice"(%93, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %95 = "tf.Shape"(%37#2) {device = ""} : (tensor) -> tensor<1xi64> + %96 = "tf.StridedSlice"(%95, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %97 = "tf.Equal"(%94, %96) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %98 = "tf.All"(%97, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %99 = "tf.If"(%98, %98, %94, %96) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6820, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6810} : (tensor, tensor, tensor, tensor) -> tensor + %100 = "tf.Identity"(%99) {device = ""} : (tensor) -> tensor + %101 = "tf.Identity"(%93) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %102 = 
"tf.Shape"(%101) {device = ""} : (tensor) -> tensor<1xi64> + %103 = "tf.StridedSlice"(%102, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %104 = "tf.Sub"(%103, %13) {device = ""} : (tensor, tensor) -> tensor + %105 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %106 = "tf.LogicalOr"(%105, %2) {device = ""} : (tensor, tensor) -> tensor + %107 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %108 = "tf.LogicalOr"(%106, %107) {device = ""} : (tensor, tensor) -> tensor + %109 = "tf.StridedSlice"(%101, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %110 = "tf.StridedSlice"(%101, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %111 = "tf.Sub"(%109, %110) {device = ""} : (tensor, tensor) -> tensor + %112 = "tf.Shape"(%101) {device = ""} : (tensor) -> tensor<1xi64> + %113 = "tf.StridedSlice"(%112, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %114 = "tf.Sub"(%113, %13) {device = ""} : (tensor, tensor) -> tensor + %115 = "tf.Equal"(%114, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %116 = "tf.ExpandDims"(%101, %7) {device = ""} : (tensor, tensor) -> tensor + %117 = "tf.Shape"(%101) {device = ""} : (tensor) -> tensor<1xi32> + %118 = "tf.StridedSlice"(%117, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %119 = "tf.StridedSlice"(%117, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %120 = "tf.StridedSlice"(%117, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %121 = "tf.StridedSlice"(%37#4, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %122 = "tf.Equal"(%121, %10) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %123 = "tf.All"(%122, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %124 = "tf.If"(%123, %123, %121, %10) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7190, is_stateless = false, output_shapes = [#tf.shape<>], 
then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7180} : (tensor, tensor, tensor, tensor) -> tensor + %125 = "tf.Identity"(%124) {device = ""} : (tensor) -> tensor + %126 = "tf.StridedSlice"(%37#4, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %127 = "tf.StridedSlice"(%37#4, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %128 = "tf.Sub"(%126, %127) {device = ""} : (tensor, tensor) -> tensor + %129 = "tf.LessEqual"(%10, %128) {device = ""} : (tensor, tensor) -> tensor + %130 = "tf.All"(%129, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %131 = "tf.If"(%130, %130, %128) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7550, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7540} : (tensor, tensor, tensor) -> tensor + %132 = "tf.Identity"(%131) {device = ""} : (tensor) -> tensor + %133 = "tf.Identity"(%37#4) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %134 = "tf.StridedSlice"(%133, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %135 = "tf.Shape"(%37#3) {device = ""} : (tensor) -> tensor<1xi64> + %136 = "tf.StridedSlice"(%135, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %137 = "tf.Equal"(%134, %136) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %138 = "tf.All"(%137, %9) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %139 = "tf.If"(%138, %138, %134, %136) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7890, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7880} : (tensor, tensor, tensor, tensor) -> tensor + %140 = "tf.Identity"(%139) {device = ""} : (tensor) -> tensor + %141 = "tf.Identity"(%133) {_class = ["loc:@WhitespaceTokenize/WhitespaceTokenize/WhitespaceTokenizeWithOffsets"], device = ""} : (tensor) -> tensor + %142 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi64> + %143 = "tf.StridedSlice"(%142, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask 
= 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %144 = "tf.Sub"(%143, %13) {device = ""} : (tensor, tensor) -> tensor + %145 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %146 = "tf.LogicalOr"(%145, %2) {device = ""} : (tensor, tensor) -> tensor + %147 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %148 = "tf.LogicalOr"(%146, %147) {device = ""} : (tensor, tensor) -> tensor + %149 = "tf.StridedSlice"(%141, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %150 = "tf.StridedSlice"(%141, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %151 = "tf.Sub"(%149, %150) {device = ""} : (tensor, tensor) -> tensor + %152 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi64> + %153 = "tf.StridedSlice"(%152, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %154 = "tf.Sub"(%153, %13) {device = ""} : (tensor, tensor) -> tensor + %155 = "tf.Equal"(%154, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %156 = "tf.ExpandDims"(%141, %7) {device = ""} : (tensor, tensor) -> tensor + %157 = "tf.Shape"(%141) {device = ""} : (tensor) -> tensor<1xi32> + %158 = "tf.StridedSlice"(%157, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %159 = "tf.StridedSlice"(%157, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %160 = "tf.StridedSlice"(%157, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %161 = "tf.StridedSlice"(%141, %5, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %162 = "tf.Range"(%10, %161, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %163 = "tf.StridedSlice"(%141, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %164 = "tf.StridedSlice"(%141, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %165 = "tf.Sub"(%163, %164) {device = ""} : (tensor, tensor) -> tensor + %166 = "tf.If"(%108, %108, %13, %104) {_lower_using_switch_merge = true, _read_only_resource_inputs = 
[], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8690, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8680} : (tensor, tensor, tensor, tensor) -> tensor + %167 = "tf.Identity"(%166) {device = ""} : (tensor) -> tensor + %168 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %169 = "tf.Select"(%168, %13, %104) {device = ""} : (tensor, tensor, tensor) -> tensor + %170 = "tf.Equal"(%169, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %171 = "tf.LogicalOr"(%170, %2) {device = ""} : (tensor, tensor) -> tensor + %172 = "tf.Equal"(%169, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %173 = "tf.LogicalOr"(%171, %172) {device = ""} : (tensor, tensor) -> tensor + %174 = "tf.Select"(%115, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %175 = "tf.Pack"(%174, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %176 = "tf.StridedSlice"(%175, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %177 = "tf.Cast"(%176) {Truncate = false, device = ""} : (tensor) -> tensor + %178 = "tf.Reshape"(%177, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %179 = "tf.Pack"(%7, %178) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %180 = "tf.Tile"(%116, %179) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %181 = "tf.Mul"(%178, %119) {device = ""} : (tensor, tensor) -> tensor + %182 = "tf.Pack"(%181) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %183 = "tf.ConcatV2"(%118, %182, %120, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %184 = "tf.Reshape"(%180, %183) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %185 = "tf.Shape"(%184) {device = ""} : (tensor) -> tensor<1xi64> + %186 = "tf.StridedSlice"(%185, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %187 = "tf.Pack"(%176) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %188 = "tf.StridedSlice"(%184, %187, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %189 = "tf.Sub"(%186, %176) {device = ""} : (tensor, tensor) -> tensor + %190 = "tf.Pack"(%189) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %191 = "tf.StridedSlice"(%184, %11, %190, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %192:2 = "tf.RaggedRange"(%191, %188, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %193 = "tf.Select"(%2, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %194 = "tf.Pack"(%193, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %195 = "tf.StridedSlice"(%194, %16, %6, %16) {begin_mask = 0 : i64, device = "", 
ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %196 = "tf.Cast"(%195) {Truncate = false, device = ""} : (tensor) -> tensor + %197 = "tf.Reshape"(%196, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %198 = "tf.Pack"(%7, %197) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %199 = "tf.Tile"(%4, %198) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %200 = "tf.Mul"(%197, %8) {device = ""} : (tensor, tensor) -> tensor + %201 = "tf.Pack"(%200) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %202 = "tf.ConcatV2"(%9, %201, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %203 = "tf.Reshape"(%199, %202) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %204 = "tf.Shape"(%203) {device = ""} : (tensor) -> tensor<1xi64> + %205 = "tf.StridedSlice"(%204, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %206 = "tf.Pack"(%195) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %207 = "tf.StridedSlice"(%203, %206, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %208 = "tf.Sub"(%205, %195) {device = ""} : (tensor, tensor) -> tensor + %209 = "tf.Pack"(%208) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %210 = "tf.StridedSlice"(%203, %11, %209, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %211:2 = "tf.RaggedRange"(%210, %207, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %212 = "tf.StridedSlice"(%194, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %213 = "tf.StridedSlice"(%194, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %214 = "tf.Mul"(%213, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %215 = "tf.Tile"(%214, %212) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %216 = "tf.Cumsum"(%215, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %217 = "tf.ConcatV2"(%11, %216, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %218 = "tf.StridedSlice"(%217, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %219 = "tf.ExpandDims"(%218, %7) {device = ""} : (tensor, tensor) -> tensor + %220 = "tf.Shape"(%218) {device = ""} : (tensor) -> tensor<1xi32> + %221 = "tf.StridedSlice"(%220, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 
: i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %222 = "tf.Pack"(%221) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %223 = "tf.StridedSlice"(%217, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %224 = "tf.ExpandDims"(%223, %7) {device = ""} : (tensor, tensor) -> tensor + %225 = "tf.Shape"(%223) {device = ""} : (tensor) -> tensor<1xi32> + %226 = "tf.StridedSlice"(%225, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %227 = "tf.Pack"(%226) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %228 = "tf.Equal"(%104, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %229 = "tf.Select"(%228, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %230 = "tf.Cast"(%229) {Truncate = false, device = ""} : (tensor) -> tensor + %231 = "tf.Reshape"(%230, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %232 = "tf.Pack"(%7, %231) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %233 = "tf.Mul"(%231, %8) {device = ""} : (tensor, tensor) -> tensor + %234 = "tf.Pack"(%233) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %235 = "tf.ConcatV2"(%9, %234, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %236 = "tf.Pack"(%229) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %237 = "tf.Pack"(%10, %104) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %238 = "tf.ExpandDims"(%237, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %239 = "tf.Tile"(%238, %232) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %240 = "tf.Reshape"(%239, %235) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %241 = "tf.Shape"(%240) {device = ""} : (tensor) -> tensor<1xi64> + %242 = "tf.StridedSlice"(%241, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %243 = "tf.Sub"(%242, %229) {device = ""} : (tensor, tensor) -> tensor + %244 = "tf.Pack"(%243) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %245 = "tf.StridedSlice"(%240, %11, %244, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %246 = "tf.StridedSlice"(%240, %236, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %247:2 = "tf.RaggedRange"(%245, %246, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %248 = "tf.GatherV2"(%111, %247#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %249 = "tf.Cast"(%248) {Truncate = false, device = ""} : (tensor) -> tensor + %250 = "tf.BroadcastTo"(%249, %222) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %251 = "tf.Max"(%250, %15) {device = "", keep_dims = false} : (tensor, 
tensor<1xi32>) -> tensor + %252 = "tf.Maximum"(%14, %251) {device = ""} : (tensor, tensor) -> tensor + %253 = "tf.Range"(%14, %252, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %254 = "tf.Pack"(%7, %252) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %255 = "tf.Tile"(%219, %254) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %256 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> + %257 = "tf.StridedSlice"(%256, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %258 = "tf.Prod"(%257, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %259 = "tf.Pack"(%258) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %260 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> + %261 = "tf.StridedSlice"(%260, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %262 = "tf.Shape"(%255) {device = ""} : (tensor) -> tensor<2xi32> + %263 = "tf.StridedSlice"(%262, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %264 = "tf.ConcatV2"(%261, %259, %263, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %265 = "tf.Reshape"(%255, %264) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %266 = "tf.ExpandDims"(%250, %3) {device = ""} : (tensor, tensor) -> tensor + %267 = "tf.Less"(%253, %266) {device = ""} : (tensor, tensor) -> tensor + %268 = "tf.Reshape"(%267, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %269 = "tf.Where"(%268) {device = ""} : (tensor) -> tensor + %270 = "tf.Squeeze"(%269) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %271 = "tf.GatherV2"(%265, %270, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %272 = "tf.Cast"(%248) {Truncate = false, device = ""} : (tensor) -> tensor + %273 = "tf.BroadcastTo"(%272, %227) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %274 = "tf.Max"(%273, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %275 = "tf.Maximum"(%14, %274) {device = ""} : (tensor, tensor) -> tensor + %276 = "tf.Range"(%14, %275, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %277 = "tf.Pack"(%7, %275) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %278 = "tf.Tile"(%224, %277) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %279 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> + %280 = "tf.StridedSlice"(%279, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %281 = "tf.Prod"(%280, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %282 = "tf.Pack"(%281) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %283 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> + %284 = "tf.StridedSlice"(%283, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask 
= 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %285 = "tf.Shape"(%278) {device = ""} : (tensor) -> tensor<2xi32> + %286 = "tf.StridedSlice"(%285, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %287 = "tf.ConcatV2"(%284, %282, %286, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %288 = "tf.Reshape"(%278, %287) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %289 = "tf.ExpandDims"(%273, %3) {device = ""} : (tensor, tensor) -> tensor + %290 = "tf.Less"(%276, %289) {device = ""} : (tensor, tensor) -> tensor + %291 = "tf.Reshape"(%290, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %292 = "tf.Where"(%291) {device = ""} : (tensor) -> tensor + %293 = "tf.Squeeze"(%292) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %294 = "tf.GatherV2"(%288, %293, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %295:2 = "tf.RaggedRange"(%271, %294, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %296 = "tf.If"(%173, %173, %169, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9760, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9750} : (tensor, tensor, tensor, tensor) -> tensor + %297 = "tf.Identity"(%296) {device = ""} : (tensor) -> tensor + %298 = "tf.Select"(%2, %169, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %299 = "tf.Pack"(%298) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %300 = "tf.ConcatV2"(%1, %299, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %301 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %302 = "tf.Equal"(%301, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %303 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %304 = "tf.StridedSlice"(%300, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %305 = "tf.Equal"(%304, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %306 = "tf.If"(%305, %305, %304, %248) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10250, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10240} : (tensor, tensor, tensor, tensor) -> tensor + %307 = "tf.Identity"(%306) {device = ""} : (tensor) -> tensor + 
%308 = "tf.If"(%302, %302, %248, %303) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10610, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10600} : (tensor, tensor, tensor, tensor) -> tensor + %309 = "tf.If"(%148, %148, %13, %144) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_15310, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_15300} : (tensor, tensor, tensor, tensor) -> tensor + %310 = "tf.Identity"(%309) {device = ""} : (tensor) -> tensor + %311 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %312 = "tf.Select"(%311, %13, %144) {device = ""} : (tensor, tensor, tensor) -> tensor + %313 = "tf.Equal"(%312, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %314 = "tf.LogicalOr"(%313, %2) {device = ""} : (tensor, tensor) -> tensor + %315 = "tf.Equal"(%312, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %316 = "tf.LogicalOr"(%314, %315) {device = ""} : (tensor, tensor) -> tensor + %317 = "tf.Select"(%155, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %318 = "tf.Pack"(%317, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %319 = "tf.StridedSlice"(%318, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %320 = "tf.Cast"(%319) {Truncate = false, device = ""} : (tensor) -> tensor + %321 = "tf.Reshape"(%320, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %322 = "tf.Pack"(%7, %321) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %323 = "tf.Tile"(%156, %322) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %324 = "tf.Mul"(%321, %159) {device = ""} : (tensor, tensor) -> tensor + %325 = "tf.Pack"(%324) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %326 = "tf.ConcatV2"(%158, %325, %160, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %327 = "tf.Reshape"(%323, %326) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %328 = "tf.Shape"(%327) {device = ""} : (tensor) -> tensor<1xi64> + %329 = "tf.StridedSlice"(%328, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %330 = "tf.Pack"(%319) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %331 = "tf.StridedSlice"(%327, %330, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %332 = "tf.Sub"(%329, %319) {device = ""} : (tensor, tensor) -> tensor + %333 = "tf.Pack"(%332) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %334 = "tf.StridedSlice"(%327, %11, %333, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, 
shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %335:2 = "tf.RaggedRange"(%334, %331, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %336 = "tf.GatherV2"(%162, %335#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %337 = "tf.StridedSlice"(%318, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %338 = "tf.StridedSlice"(%318, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %339 = "tf.StridedSlice"(%318, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %340 = "tf.ConcatV2"(%338, %339, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %341 = "tf.StridedSlice"(%318, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %342 = "tf.Mul"(%165, %341) {device = ""} : (tensor, tensor) -> tensor + %343 = "tf.Tile"(%342, %337) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %344 = "tf.Cumsum"(%343, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %345 = "tf.ConcatV2"(%11, %344, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %346 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi64> + %347 = "tf.StridedSlice"(%346, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %348 = "tf.Sub"(%347, %13) {device = ""} : (tensor, tensor) -> tensor + %349 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %350 = "tf.LogicalOr"(%349, %2) {device = ""} : (tensor, tensor) -> tensor + %351 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %352 = "tf.LogicalOr"(%350, %351) {device = ""} : (tensor, tensor) -> tensor + %353 = "tf.StridedSlice"(%345, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %354 = "tf.StridedSlice"(%345, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %355 = "tf.Sub"(%353, %354) {device = ""} : (tensor, tensor) -> tensor + %356 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi64> + %357 = "tf.StridedSlice"(%356, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %358 = "tf.Sub"(%357, %13) {device = ""} : 
(tensor, tensor) -> tensor + %359 = "tf.Equal"(%358, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %360 = "tf.ExpandDims"(%345, %7) {device = ""} : (tensor, tensor) -> tensor + %361 = "tf.Shape"(%345) {device = ""} : (tensor) -> tensor<1xi32> + %362 = "tf.StridedSlice"(%361, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %363 = "tf.StridedSlice"(%361, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %364 = "tf.StridedSlice"(%361, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %365 = "tf.Select"(%2, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %366 = "tf.Pack"(%365, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %367 = "tf.StridedSlice"(%366, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %368 = "tf.Cast"(%367) {Truncate = false, device = ""} : (tensor) -> tensor + %369 = "tf.Reshape"(%368, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %370 = "tf.Pack"(%7, %369) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %371 = "tf.Tile"(%4, %370) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %372 = "tf.Mul"(%369, %8) {device = ""} : (tensor, tensor) -> tensor + %373 = "tf.Pack"(%372) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %374 = "tf.ConcatV2"(%9, %373, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %375 = "tf.Reshape"(%371, %374) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %376 = "tf.Shape"(%375) {device = ""} : (tensor) -> tensor<1xi64> + %377 = "tf.StridedSlice"(%376, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %378 = "tf.Pack"(%367) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %379 = "tf.StridedSlice"(%375, %378, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %380 = "tf.Sub"(%377, %367) {device = ""} : (tensor, tensor) -> tensor + %381 = "tf.Pack"(%380) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %382 = "tf.StridedSlice"(%375, %11, %381, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %383:2 = "tf.RaggedRange"(%382, %379, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %384 = "tf.GatherV2"(%11, %383#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %385 = 
"tf.GatherV2"(%12, %384, %14) {batch_dims = 0 : i64, device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %386 = "tf.StridedSlice"(%366, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %387 = "tf.StridedSlice"(%366, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %388 = "tf.StridedSlice"(%366, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi64> + %389 = "tf.ConcatV2"(%387, %388, %14) {device = ""} : (tensor<1xi64>, tensor<0xi64>, tensor) -> tensor<1xi64> + %390 = "tf.Tile"(%385, %389) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %391 = "tf.StridedSlice"(%366, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %392 = "tf.Mul"(%391, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %393 = "tf.Tile"(%392, %386) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %394 = "tf.Cumsum"(%393, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %395 = "tf.ConcatV2"(%11, %394, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %396 = "tf.StridedSlice"(%395, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %397 = "tf.ExpandDims"(%396, %7) {device = ""} : (tensor, tensor) -> tensor + %398 = "tf.Shape"(%396) {device = ""} : (tensor) -> tensor<1xi32> + %399 = "tf.StridedSlice"(%398, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %400 = "tf.Pack"(%399) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %401 = "tf.StridedSlice"(%395, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %402 = "tf.ExpandDims"(%401, %7) {device = ""} : (tensor, tensor) -> tensor + %403 = "tf.Shape"(%401) {device = ""} : (tensor) -> tensor<1xi32> + %404 = "tf.StridedSlice"(%403, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %405 = "tf.Pack"(%404) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %406 = "tf.Equal"(%144, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %407 = "tf.Select"(%406, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %408 = "tf.Cast"(%407) {Truncate = false, device = ""} : (tensor) -> tensor + %409 = "tf.Reshape"(%408, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %410 = 
"tf.Pack"(%7, %409) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %411 = "tf.Mul"(%409, %8) {device = ""} : (tensor, tensor) -> tensor + %412 = "tf.Pack"(%411) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %413 = "tf.ConcatV2"(%9, %412, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %414 = "tf.Pack"(%407) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %415 = "tf.Pack"(%10, %144) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %416 = "tf.ExpandDims"(%415, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %417 = "tf.Tile"(%416, %410) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %418 = "tf.Reshape"(%417, %413) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %419 = "tf.Shape"(%418) {device = ""} : (tensor) -> tensor<1xi64> + %420 = "tf.StridedSlice"(%419, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %421 = "tf.Sub"(%420, %407) {device = ""} : (tensor, tensor) -> tensor + %422 = "tf.Pack"(%421) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %423 = "tf.StridedSlice"(%418, %11, %422, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %424 = "tf.StridedSlice"(%418, %414, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %425:2 = "tf.RaggedRange"(%423, %424, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %426 = "tf.GatherV2"(%151, %425#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %427 = "tf.Cast"(%426) {Truncate = false, device = ""} : (tensor) -> tensor + %428 = "tf.BroadcastTo"(%427, %400) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %429 = "tf.Max"(%428, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %430 = "tf.Maximum"(%14, %429) {device = ""} : (tensor, tensor) -> tensor + %431 = "tf.Range"(%14, %430, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %432 = "tf.Pack"(%7, %430) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %433 = "tf.Tile"(%397, %432) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %434 = "tf.Shape"(%433) {device = ""} : (tensor) -> tensor<2xi32> + %435 = "tf.StridedSlice"(%434, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %436 = "tf.Prod"(%435, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %437 = "tf.Pack"(%436) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %438 = "tf.Shape"(%433) {device = ""} : (tensor) -> tensor<2xi32> + %439 = "tf.StridedSlice"(%438, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %440 = "tf.Shape"(%433) {device = ""} : 
(tensor) -> tensor<2xi32> + %441 = "tf.StridedSlice"(%440, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %442 = "tf.ConcatV2"(%439, %437, %441, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %443 = "tf.Reshape"(%433, %442) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %444 = "tf.ExpandDims"(%428, %3) {device = ""} : (tensor, tensor) -> tensor + %445 = "tf.Less"(%431, %444) {device = ""} : (tensor, tensor) -> tensor + %446 = "tf.Reshape"(%445, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %447 = "tf.Where"(%446) {device = ""} : (tensor) -> tensor + %448 = "tf.Squeeze"(%447) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %449 = "tf.GatherV2"(%443, %448, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %450 = "tf.Cast"(%426) {Truncate = false, device = ""} : (tensor) -> tensor + %451 = "tf.BroadcastTo"(%450, %405) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %452 = "tf.Max"(%451, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %453 = "tf.Maximum"(%14, %452) {device = ""} : (tensor, tensor) -> tensor + %454 = "tf.Range"(%14, %453, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %455 = "tf.Pack"(%7, %453) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %456 = "tf.Tile"(%402, %455) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %457 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> + %458 = "tf.StridedSlice"(%457, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %459 = "tf.Prod"(%458, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %460 = "tf.Pack"(%459) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %461 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> + %462 = "tf.StridedSlice"(%461, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %463 = "tf.Shape"(%456) {device = ""} : (tensor) -> tensor<2xi32> + %464 = "tf.StridedSlice"(%463, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %465 = "tf.ConcatV2"(%462, %460, %464, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %466 = "tf.Reshape"(%456, %465) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %467 = "tf.ExpandDims"(%451, %3) {device = ""} : (tensor, tensor) -> tensor + %468 = "tf.Less"(%454, %467) {device = ""} : (tensor, tensor) -> tensor + %469 = "tf.Reshape"(%468, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %470 = "tf.Where"(%469) {device = ""} : (tensor) -> tensor + %471 = "tf.Squeeze"(%470) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %472 = "tf.GatherV2"(%466, %471, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %473:2 = "tf.RaggedRange"(%449, %472, %13) {T = i64, Tsplits = i64, 
device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %474 = "tf.GatherV2"(%390, %473#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %475 = "tf.If"(%316, %316, %312, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_16380, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_16370} : (tensor, tensor, tensor, tensor) -> tensor + %476 = "tf.Identity"(%475) {device = ""} : (tensor) -> tensor + %477 = "tf.Select"(%2, %312, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %478 = "tf.Pack"(%477) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %479 = "tf.ConcatV2"(%1, %478, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %480 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %481 = "tf.Equal"(%480, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %482 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %483 = "tf.StridedSlice"(%479, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %484 = "tf.Equal"(%483, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %485 = "tf.If"(%484, %484, %483, %426) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_16870, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_16860} : (tensor, tensor, tensor, tensor) -> tensor + %486 = "tf.Identity"(%485) {device = ""} : (tensor) -> tensor + %487 = "tf.If"(%481, %481, %426, %482) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_17230, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_17220} : (tensor, tensor, tensor, tensor) -> tensor + %488 = "tf.Identity"(%487) {device = ""} : (tensor) -> tensor + %489 = "tf.If"(%352, %352, %13, %348) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21910, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21900} : (tensor, tensor, tensor, tensor) -> tensor + %490 = "tf.Identity"(%489) {device = ""} : (tensor) -> tensor + %491 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %492 = "tf.Select"(%491, %13, %348) {device = ""} : (tensor, tensor, tensor) -> tensor + %493 = "tf.Equal"(%492, %13) {device = "", 
incompatible_shape_error = true} : (tensor, tensor) -> tensor + %494 = "tf.LogicalOr"(%493, %2) {device = ""} : (tensor, tensor) -> tensor + %495 = "tf.Equal"(%492, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %496 = "tf.LogicalOr"(%494, %495) {device = ""} : (tensor, tensor) -> tensor + %497 = "tf.Select"(%359, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %498 = "tf.Pack"(%497, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %499 = "tf.StridedSlice"(%498, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %500 = "tf.Cast"(%499) {Truncate = false, device = ""} : (tensor) -> tensor + %501 = "tf.Reshape"(%500, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %502 = "tf.Pack"(%7, %501) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %503 = "tf.Tile"(%360, %502) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %504 = "tf.Mul"(%501, %363) {device = ""} : (tensor, tensor) -> tensor + %505 = "tf.Pack"(%504) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %506 = "tf.ConcatV2"(%362, %505, %364, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %507 = "tf.Reshape"(%503, %506) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %508 = "tf.Shape"(%507) {device = ""} : (tensor) -> tensor<1xi64> + %509 = "tf.StridedSlice"(%508, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %510 = "tf.Pack"(%499) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %511 = "tf.StridedSlice"(%507, %510, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %512 = "tf.Sub"(%509, %499) {device = ""} : (tensor, tensor) -> tensor + %513 = "tf.Pack"(%512) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %514 = "tf.StridedSlice"(%507, %11, %513, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %515:2 = "tf.RaggedRange"(%514, %511, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %516 = "tf.Select"(%2, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %517 = "tf.Pack"(%516, %13) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %518 = "tf.StridedSlice"(%517, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %519 = "tf.Cast"(%518) {Truncate = false, device = ""} : (tensor) -> tensor + %520 = "tf.Reshape"(%519, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %521 = "tf.Pack"(%7, %520) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %522 = "tf.Tile"(%4, %521) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %523 = "tf.Mul"(%520, %8) {device = ""} : (tensor, tensor) -> tensor + %524 = "tf.Pack"(%523) {axis 
= 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %525 = "tf.ConcatV2"(%9, %524, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %526 = "tf.Reshape"(%522, %525) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %527 = "tf.Shape"(%526) {device = ""} : (tensor) -> tensor<1xi64> + %528 = "tf.StridedSlice"(%527, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %529 = "tf.Pack"(%518) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %530 = "tf.StridedSlice"(%526, %529, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %531 = "tf.Sub"(%528, %518) {device = ""} : (tensor, tensor) -> tensor + %532 = "tf.Pack"(%531) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %533 = "tf.StridedSlice"(%526, %11, %532, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %534:2 = "tf.RaggedRange"(%533, %530, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %535 = "tf.StridedSlice"(%517, %15, %16, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %536 = "tf.StridedSlice"(%517, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %537 = "tf.Mul"(%536, %12) {device = ""} : (tensor, tensor<1xi64>) -> tensor<1xi64> + %538 = "tf.Tile"(%537, %535) {device = ""} : (tensor<1xi64>, tensor<1xi64>) -> tensor + %539 = "tf.Cumsum"(%538, %14) {device = "", exclusive = false, reverse = false} : (tensor, tensor) -> tensor + %540 = "tf.ConcatV2"(%11, %539, %3) {device = ""} : (tensor<1xi64>, tensor, tensor) -> tensor + %541 = "tf.StridedSlice"(%540, %15, %5, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %542 = "tf.ExpandDims"(%541, %7) {device = ""} : (tensor, tensor) -> tensor + %543 = "tf.Shape"(%541) {device = ""} : (tensor) -> tensor<1xi32> + %544 = "tf.StridedSlice"(%543, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %545 = "tf.Pack"(%544) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %546 = "tf.StridedSlice"(%540, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %547 = "tf.ExpandDims"(%546, %7) {device = ""} : (tensor, tensor) -> tensor + %548 = "tf.Shape"(%546) {device = ""} : (tensor) -> tensor<1xi32> + %549 = "tf.StridedSlice"(%548, %15, %16, %16) {begin_mask = 0 : 
i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %550 = "tf.Pack"(%549) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %551 = "tf.Equal"(%348, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %552 = "tf.Select"(%551, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %553 = "tf.Cast"(%552) {Truncate = false, device = ""} : (tensor) -> tensor + %554 = "tf.Reshape"(%553, %9) {device = ""} : (tensor, tensor<0xi32>) -> tensor + %555 = "tf.Pack"(%7, %554) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %556 = "tf.Mul"(%554, %8) {device = ""} : (tensor, tensor) -> tensor + %557 = "tf.Pack"(%556) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %558 = "tf.ConcatV2"(%9, %557, %9, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %559 = "tf.Pack"(%552) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %560 = "tf.Pack"(%10, %348) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi64> + %561 = "tf.ExpandDims"(%560, %7) {device = ""} : (tensor<2xi64>, tensor) -> tensor<2x1xi64> + %562 = "tf.Tile"(%561, %555) {device = ""} : (tensor<2x1xi64>, tensor<2xi32>) -> tensor<2x?xi64> + %563 = "tf.Reshape"(%562, %558) {device = ""} : (tensor<2x?xi64>, tensor<1xi32>) -> tensor + %564 = "tf.Shape"(%563) {device = ""} : (tensor) -> tensor<1xi64> + %565 = "tf.StridedSlice"(%564, %15, %16, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %566 = "tf.Sub"(%565, %552) {device = ""} : (tensor, tensor) -> tensor + %567 = "tf.Pack"(%566) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %568 = "tf.StridedSlice"(%563, %11, %567, %12) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %569 = "tf.StridedSlice"(%563, %559, %11, %12) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + %570:2 = "tf.RaggedRange"(%568, %569, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %571 = "tf.GatherV2"(%355, %570#1, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %572 = "tf.Cast"(%571) {Truncate = false, device = ""} : (tensor) -> tensor + %573 = "tf.BroadcastTo"(%572, %545) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %574 = "tf.Max"(%573, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %575 = "tf.Maximum"(%14, %574) {device = ""} : (tensor, tensor) -> tensor + %576 = "tf.Range"(%14, %575, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %577 = "tf.Pack"(%7, %575) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %578 = "tf.Tile"(%542, %577) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %579 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> + %580 = "tf.StridedSlice"(%579, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : 
(tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %581 = "tf.Prod"(%580, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %582 = "tf.Pack"(%581) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %583 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> + %584 = "tf.StridedSlice"(%583, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %585 = "tf.Shape"(%578) {device = ""} : (tensor) -> tensor<2xi32> + %586 = "tf.StridedSlice"(%585, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %587 = "tf.ConcatV2"(%584, %582, %586, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %588 = "tf.Reshape"(%578, %587) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %589 = "tf.ExpandDims"(%573, %3) {device = ""} : (tensor, tensor) -> tensor + %590 = "tf.Less"(%576, %589) {device = ""} : (tensor, tensor) -> tensor + %591 = "tf.Reshape"(%590, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %592 = "tf.Where"(%591) {device = ""} : (tensor) -> tensor + %593 = "tf.Squeeze"(%592) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %594 = "tf.GatherV2"(%588, %593, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %595 = "tf.Cast"(%571) {Truncate = false, device = ""} : (tensor) -> tensor + %596 = "tf.BroadcastTo"(%595, %550) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %597 = "tf.Max"(%596, %15) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %598 = "tf.Maximum"(%14, %597) {device = ""} : (tensor, tensor) -> tensor + %599 = "tf.Range"(%14, %598, %7) {device = ""} : (tensor, tensor, tensor) -> tensor + %600 = "tf.Pack"(%7, %598) {axis = 0 : i64, device = ""} : (tensor, tensor) -> tensor<2xi32> + %601 = "tf.Tile"(%547, %600) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %602 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> + %603 = "tf.StridedSlice"(%602, %15, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %604 = "tf.Prod"(%603, %15) {device = "", keep_dims = false} : (tensor<2xi32>, tensor<1xi32>) -> tensor + %605 = "tf.Pack"(%604) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> + %606 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> + %607 = "tf.StridedSlice"(%606, %15, %15, %16) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %608 = "tf.Shape"(%601) {device = ""} : (tensor) -> tensor<2xi32> + %609 = "tf.StridedSlice"(%608, %6, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %610 = "tf.ConcatV2"(%607, %605, %609, %14) {device = ""} : (tensor<0xi32>, tensor<1xi32>, tensor<0xi32>, tensor) -> tensor<1xi32> + %611 = 
"tf.Reshape"(%601, %610) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %612 = "tf.ExpandDims"(%596, %3) {device = ""} : (tensor, tensor) -> tensor + %613 = "tf.Less"(%599, %612) {device = ""} : (tensor, tensor) -> tensor + %614 = "tf.Reshape"(%613, %5) {device = ""} : (tensor, tensor<1xi32>) -> tensor + %615 = "tf.Where"(%614) {device = ""} : (tensor) -> tensor + %616 = "tf.Squeeze"(%615) {device = "", squeeze_dims = [1]} : (tensor) -> tensor + %617 = "tf.GatherV2"(%611, %616, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %618:2 = "tf.RaggedRange"(%594, %617, %13) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %619 = "tf.If"(%496, %496, %492, %13) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22980, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22970} : (tensor, tensor, tensor, tensor) -> tensor + %620 = "tf.Identity"(%619) {device = ""} : (tensor) -> tensor + %621 = "tf.Select"(%2, %492, %13) {device = ""} : (tensor, tensor, tensor) -> tensor + %622 = "tf.Pack"(%621) {axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi64> + %623 = "tf.ConcatV2"(%1, %622, %12, %14) {device = ""} : (tensor<0xi64>, tensor<1xi64>, tensor<1xi64>, tensor) -> tensor<2xi64> + %624 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %625 = "tf.Equal"(%624, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %626 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %627 = "tf.StridedSlice"(%623, %16, %6, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %628 = "tf.Equal"(%627, %13) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %629 = "tf.If"(%628, %628, %627, %571) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23470, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23460} : (tensor, tensor, tensor, tensor) -> tensor + %630 = "tf.Identity"(%629) {device = ""} : (tensor) -> tensor + %631 = "tf.If"(%625, %625, %571, %626) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = "", else_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23830, is_stateless = false, output_shapes = [#tf.shape<>], then_branch = @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23820} : (tensor, tensor, tensor, tensor) -> tensor + %632 = "tf.Identity"(%631) {device = ""} : (tensor) -> tensor + %633 = "tf.Identity"(%308) {device = ""} : (tensor) -> tensor + %634 = "tf.Shape"(%37#2) {device = ""} : (tensor) -> 
tensor<1xi32> + %635 = "tf.StridedSlice"(%634, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %636 = "tf.Cast"(%635) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %637 = "tf.Identity"(%636) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %638 = "tf.Shape"(%37#3) {device = ""} : (tensor) -> tensor<1xi32> + %639 = "tf.StridedSlice"(%638, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %640 = "tf.Cast"(%639) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %641 = "tf.Identity"(%640) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %642 = "tf.GatherV2"(%37#3, %336, %14) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %643 = "tf.Tile"(%642, %340) {device = ""} : (tensor, tensor<1xi64>) -> tensor + %644 = "tf.Sub"(%643, %474) {device = ""} : (tensor, tensor) -> tensor + %645 = "tf.Shape"(%644) {device = ""} : (tensor) -> tensor<1xi32> + %646 = "tf.StridedSlice"(%645, %16, %15, %16) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xi32> + %647 = "tf.Cast"(%646) {Truncate = false, device = ""} : (tensor<0xi32>) -> tensor<0xi64> + %648 = "tf.Identity"(%647) {device = ""} : (tensor<0xi64>) -> tensor<0xi64> + %649 = "tf.UnicodeEncode"(%37#0, %58) {Tsplits = i64, device = "", errors = "replace", output_encoding = "UTF-8", replacement_char = 65533 : i64} : (tensor, tensor) -> tensor + %650 = "tf.Identity"(%649) {device = ""} : (tensor) -> tensor + return %650 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_false_3220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Input tensors have incompatible shapes."> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedFromTensor/Const:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedConcat/RaggedNRows/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedConcat_assert_equal_1_Assert_AssertGuard_true_3210(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_3980(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_3970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_4340(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_4330(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_4680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = 
dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_4670(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_5050(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_5040(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_5410(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + 
"tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_5400(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_5770(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_5760(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6120(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6110(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_6480(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_6470(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_6820(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_1/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_1_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_6810(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7190(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7180(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_7550(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_7540(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_7890(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = 
"tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (WhitespaceTokenize/WhitespaceTokenize/RaggedFromNestedRowSplits_2/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedFromNestedRowSplits_2_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_7880(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_false_8690(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_AssertGuard_true_8680(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_false_9760(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_1_AssertGuard_true_9750(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = 
"tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_false_10250(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_2_AssertGuard_true_10240(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_false_10610(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_Assert_3_AssertGuard_true_10600(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_false_15310(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_AssertGuard_true_15300(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes 
{tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_false_16380(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_1_AssertGuard_true_16370(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_false_16870(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_2_AssertGuard_true_16860(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_false_17230(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func 
@WhitespaceTokenize_WhitespaceTokenize_Assert_3_AssertGuard_true_17220(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_false_21910(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_AssertGuard_true_21900(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_false_22980(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_1_AssertGuard_true_22970(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_false_23470(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> 
() + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_2_AssertGuard_true_23460(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_false_23830(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Unable to broadcast: dimension size mismatch in dimension"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"lengths="> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"dim_size="> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23820(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} + +// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} { +// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor) -> tensor +// CHECK: return %0 : tensor + +func @ngrams(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._input_shapes = [#tf.shape], tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>} { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<[0, -1]> : tensor<2xi32>} : () -> tensor<2xi32> + %2 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32> + %3 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> tensor<2xi32> + %4 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> + %5 = "tf.StridedSlice"(%arg0, %3, %1, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + %6 = "tf.StridedSlice"(%arg0, %2, %3, %4) {begin_mask = 0 : i64, device = "", ellipsis_mask = 1 : i64, end_mask = 2 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor + %7 = "tf.Pack"(%5, %6) {axis = -1 : i64, device = ""} : (tensor, tensor) -> tensor + %8 = "tf.ReduceJoin"(%7, %0) {device = "", keep_dims = false, separator = " "} : (tensor, tensor) -> tensor + %9 = 
"tf.Identity"(%8) {device = ""} : (tensor) -> tensor + return %9 : tensor +} + +// CHECK: func @ngrams(%arg0: tensor {tf._user_specified_name = "input"}) -> tensor attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = " ", width = 2 : i64}>, tf._input_shapes = [#tf.shape]} { +// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F72000120006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E383F040104FF152D0204141404082401"> : tensor<78xi8>} : (tensor) -> tensor +// CHECK: return %0 : tensor +// CHECK: } + +func @ngrams_ragged_rank_2(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor {tf._user_specified_name = "args_1"}) -> (tensor, tensor<3xi64>, tensor) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape, #tf.shape<3>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %4 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %5 = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> + %6 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32> + %7 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %8 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %9 = "tf.StridedSlice"(%arg1, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %10 = "tf.Equal"(%9, %4) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %11 = "tf.All"(%10, %5) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %12 = "tf.StridedSlice"(%arg1, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %13 = "tf.StridedSlice"(%arg1, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %14 = "tf.Sub"(%12, %13) {device = ""} : (tensor<2xi64>, tensor<2xi64>) -> tensor<2xi64> + %15 = "tf.LessEqual"(%4, %14) {device = ""} : (tensor, tensor<2xi64>) -> tensor<2xi1> + %16 = "tf.All"(%15, %7) {device = "", keep_dims = false} : (tensor<2xi1>, tensor<1xi32>) -> tensor + %17 = "tf.StridedSlice"(%arg2, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %18 = "tf.Equal"(%17, %4) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %19 = "tf.All"(%18, %5) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %20 = "tf.IfRegion"(%19) ( { + %72 = "std.call"(%19, %17, 
%4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%19, %17, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %21 = "tf.Identity"(%20) {device = ""} : (tensor) -> tensor + %22 = "tf.StridedSlice"(%arg2, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %23 = "tf.StridedSlice"(%arg2, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %24 = "tf.Sub"(%22, %23) {device = ""} : (tensor, tensor) -> tensor + %25 = "tf.LessEqual"(%4, %24) {device = ""} : (tensor, tensor) -> tensor + %26 = "tf.All"(%25, %7) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %27 = "tf.IfRegion"(%26) ( { + %72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130} : (tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%26, %24) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140} : (tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %28 = "tf.Identity"(%27) {device = ""} : (tensor) -> tensor + %29 = "tf.Identity"(%arg2) {_class = ["loc:@args_1"], device = ""} : (tensor) -> tensor + %30 = "tf.StridedSlice"(%29, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %31 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi64> + %32 = "tf.StridedSlice"(%31, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %33 = "tf.Equal"(%30, %32) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %34 = "tf.All"(%33, %5) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %35 = "tf.IfRegion"(%34) ( { + %72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%34, %30, %32) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %36 = "tf.Identity"(%35) {device = ""} : (tensor) -> tensor + %37 = "tf.Identity"(%29) {_class = ["loc:@args_1"], device = ""} : (tensor) -> tensor + %38 = "tf.StridedSlice"(%37, %7, %6, %8) {begin_mask = 1 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : 
i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %39 = "tf.StridedSlice"(%37, %8, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 1 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %40 = "tf.Minimum"(%38, %39) {device = ""} : (tensor, tensor) -> tensor + %41 = "tf.AddV2"(%39, %1) {device = ""} : (tensor, tensor) -> tensor + %42 = "tf.Maximum"(%41, %38) {device = ""} : (tensor, tensor) -> tensor + %43:2 = "tf.RaggedRange"(%40, %42, %3) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %44 = "tf.GatherV2"(%arg0, %43#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %45 = "tf.AddV2"(%38, %3) {device = ""} : (tensor, tensor) -> tensor + %46 = "tf.Minimum"(%45, %39) {device = ""} : (tensor, tensor) -> tensor + %47:2 = "tf.RaggedRange"(%46, %39, %3) {T = i64, Tsplits = i64, device = ""} : (tensor, tensor, tensor) -> (tensor, tensor) + %48 = "tf.Equal"(%43#0, %47#0) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %49 = "tf.All"(%48, %7) {device = "", keep_dims = false} : (tensor, tensor<1xi32>) -> tensor + %50 = "tf.GatherV2"(%arg0, %47#1, %2) {batch_dims = 0 : i64, device = ""} : (tensor, tensor, tensor) -> tensor + %51 = "tf.Shape"(%37) {device = ""} : (tensor) -> tensor<1xi64> + %52 = "tf.StridedSlice"(%51, %7, %8, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<1xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %53 = "tf.Sub"(%52, %3) {device = ""} : (tensor, tensor) -> tensor + %54 = "tf.IfRegion"(%11) ( { + %72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%11, %9, %4) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %55 = "tf.Identity"(%54) {device = ""} : (tensor) -> tensor + %56 = "tf.IfRegion"(%16) ( { + %72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260} : (tensor, tensor<2xi64>) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%16, %14) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270} : (tensor, tensor<2xi64>) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %57 = "tf.Identity"(%56) {device = ""} : (tensor) -> tensor + %58 = "tf.Identity"(%arg1) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64> + %59 = "tf.StridedSlice"(%58, %6, %7, %8) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %60 = "tf.Equal"(%59, %53) {device = "", incompatible_shape_error = true} : (tensor, tensor) -> tensor + %61 = 
"tf.All"(%60, %5) {device = "", keep_dims = false} : (tensor, tensor<0xi32>) -> tensor + %62 = "tf.IfRegion"(%61) ( { + %72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%61, %59, %53) {callee = @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %63 = "tf.IfRegion"(%49) ( { + %72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }, { + %72 = "std.call"(%49, %43#0, %47#0) {callee = @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340} : (tensor, tensor, tensor) -> tensor + "tf.Yield"(%72) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + %64 = "tf.Identity"(%43#0) {device = ""} : (tensor) -> tensor + %65 = "tf.Identity"(%63) {device = ""} : (tensor) -> tensor + %66 = "tf.Pack"(%44, %50) {axis = 1 : i64, device = ""} : (tensor, tensor) -> tensor + %67 = "tf.ReduceJoin"(%66, %0) {device = "", keep_dims = false, separator = ""} : (tensor, tensor) -> tensor + %68 = "tf.Identity"(%67) {device = ""} : (tensor) -> tensor + %69 = "tf.Identity"(%62) {device = ""} : (tensor) -> tensor + %70 = "tf.Identity"(%58) {_class = ["loc:@args_0"], device = ""} : (tensor<3xi64>) -> tensor<3xi64> + %71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64> + return %68, %71, %64 : tensor, tensor<3xi64>, tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor, %arg1: tensor) -> tensor attributes 
{sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor +"tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor} : () -> tensor + %1 = 
"tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/Const:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor<2xi64>) -> () + %3 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %4 = "tf.Identity"(%3) {device = ""} : (tensor) -> tensor + return %4 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RaggedNRows/sub:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = 
"tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape]} { + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + return %1 : tensor +} +func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape, #tf.shape], tf.signature.is_stateful} { + %0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor} : () -> tensor + %1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor} : () -> tensor + %3 = "tf.Const"() {value = dense<"y (NGrams/SlidingWindow/RaggedGetItem_1/RaggedRange:0) = "> : tensor} : () -> tensor + "tf.Assert"(%arg0, %0, %1, %2, %arg1, %3, %arg2) {device = "", summarize = 3 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> () + %4 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor +} +// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor {tf._user_specified_name = "args_1"}) -> (tensor, tensor<3xi64>, tensor) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape, #tf.shape<3>, #tf.shape], tf.signature.is_stateful} { +// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor, tensor<3xi64>, tensor) -> (tensor, tensor<3xi64>, tensor) +// CHECK: return %0#0, %0#1, %0#2 : tensor, tensor<3xi64>, tensor \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir index 1a61bc3f517..1ebe912284b 100644 --- a/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir +++ b/tensorflow/compiler/mlir/lite/tests/lower-static-tensor-list.mlir @@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor, %arg1: tensor) -> ten // CHECK: return %[[RESULT]] : tensor } +// CHECK-LABEL: func @tensorlistWhileRegion +func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { + %cst = constant dense<3> : tensor<1xi32> + %cst_0 = constant dense<0> : tensor + %cst_1 = constant dense<-1> : tensor + %0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor>> + // CHECK: "tf.WhileRegion" + %1:2 = "tf.WhileRegion"(%cst_0, %0) ({ + ^bb0(%carg0: tensor, %carg1: tensor): + %cst_2 = constant dense<2> : tensor + %1 = "tf.Less"(%carg0, %cst_2) : (tensor, tensor) -> tensor + "tf.Yield"(%1) : (tensor) -> () + + // verify condition types + // CHECK: ^bb0(%[[CARG0:.*]]: tensor, %[[CARG1:.*]]: 
tensor<*xf32>): + // CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor, tensor) -> tensor + // CHECK: "tf.Yield"(%[[COND]]) : (tensor) -> () + + }, + { + ^bb0(%barg0: tensor, %barg1: tensor): + %1 = "tf.TensorListLength"(%barg1) : (tensor) -> tensor + "tf.Yield"(%1, %barg1) : (tensor, tensor) -> () + + // verify body types + // CHECK: ^bb0(%[[BARG0:.*]]: tensor, %[[BARG1:.*]]: tensor<*xf32>): + // CHECK-NOT: tensor + // CHECK: %[[LEN:.*]] = "tf.Gather" + // CHECK-NOT: tensor + // CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor, tensor<*xf32>) -> () + + }) {is_stateless = false} : (tensor, tensor>>) -> (tensor, tensor>>) + // make sure the variant types in input/output have been updated + // CHECK: {is_stateless = false} : (tensor, tensor<2x3xf32>) -> (tensor, tensor<*xf32>) + %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor>>, tensor) -> tensor<*xf32> + // CHECK: return %0#1 : tensor<*xf32> + return %2 : tensor<*xf32> +} + func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor) -> tensor { %0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor>> %1 = "tf.TensorListResize"(%0, %arg2) : (tensor>>, tensor) -> tensor>> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir new file mode 100644 index 00000000000..a5e6d4aabb5 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/flex_op_with_complex128.mlir @@ -0,0 +1,66 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s + +func @main(tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> { +^bb0(%arg0: tensor<4xcomplex>, %arg1: tensor<4xcomplex>): +// CHECK: { +// CHECK-NEXT: version: 3, +// CHECK-NEXT: operator_codes: [ { +// CHECK-NEXT: builtin_code: CUSTOM, +// CHECK-NEXT: custom_code: "FlexAdd" +// CHECK-NEXT: } ], +// CHECK-NEXT: subgraphs: [ { +// CHECK-NEXT: tensors: [ { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 1, +// CHECK-NEXT: name: "arg0", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 2, +// CHECK-NEXT: name: "arg1", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: }, { +// CHECK-NEXT: shape: [ 4 ], +// CHECK-NEXT: type: COMPLEX128, +// CHECK-NEXT: buffer: 3, +// CHECK-NEXT: name: "add", +// CHECK-NEXT: quantization: { +// CHECK-EMPTY: +// CHECK-NEXT: } +// CHECK-NEXT: } ], +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: operators: [ { +// CHECK-NEXT: inputs: [ 0, 1 ], +// CHECK-NEXT: outputs: [ 2 ], +// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: name: "main" +// CHECK-NEXT: } ], +// CHECK-NEXT: description: "MLIR Converted.", +// CHECK-NEXT: buffers: [ { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-EMPTY: +// CHECK-NEXT: }, { +// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ] +// CHECK-NEXT: } ], +// CHECK-NEXT: metadata: [ { +// CHECK-NEXT: name: "min_runtime_version", +// CHECK-NEXT: buffer: 4 +// CHECK-NEXT: } ] +// CHECK-NEXT:} + + %0 = 
"tf.Add"(%arg0, %arg1) : (tensor<4xcomplex>, tensor<4xcomplex>) -> tensor<4xcomplex> loc("add") + return %0 : tensor<4xcomplex> +} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 5f434e954c8..7ef6997f938 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -598,6 +598,16 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform, %arg1: tensor<1x64x64x32xf32>, %arg2: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { + %0, %1, %2, %3 = "tfl.custom_tf"(%arg0, %arg1, %arg2) ({ + %4, %5, %6, %7 = "tf.TFLite_Detection_PostProcess"(%arg0, %arg1, %arg2) {_output_quantized = true, _output_types = [f32, f32, f32, f32], _support_output_type_float_in_quantized_op = true, detections_per_class = 100 : i64, device = "", h_scale = 5.000000e+00 : f32, max_classes_per_detection = 1 : i64, max_detections = 20 : i64, nms_iou_threshold = 6.000000e-01 : f32, nms_score_threshold = 3.000000e-01 : f32, num_classes = 90 : i64, use_regular_nms = false, w_scale = 5.000000e+00 : f32, x_scale = 1.000000e+01 : f32, y_scale = 1.000000e+01 : f32} : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + "tfl.yield"(%4, %5, %6, %7) : (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) -> () + }) : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) + return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32> +} + +// ----- + func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) %0, %1 = "tfl.custom"(%arg0) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) @@ -1238,6 +1248,13 @@ func @testSpaceToBatchND(%arg0 : tensor<1x4x4x3xf32>, %arg1 : tensor<2xi32>, %ar // ----- +func @testBatchMatmulQuant(%arg0 : tensor<1x4x384x32x!quant.uniform>, %arg1 : tensor<1x4x384x32x!quant.uniform>) -> tensor<1x4x384x384x!quant.uniform> { + // CHECK: "tfl.batch_matmul"(%arg0, %arg1) + %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32x!quant.uniform>, tensor<1x4x384x32x!quant.uniform>) -> tensor<1x4x384x384x!quant.uniform> + return %0 : tensor<1x4x384x384x!quant.uniform> +} +// ----- + func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> { // CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 7861eb1ec6b..7923c82ba92 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ 
b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -400,6 +400,32 @@ func @FuseFullyConnectedReshapeAddConst(%arg0: tensor<40x37xf32>, %arg1: tensor< // FOLD: return %[[fc]] } +// CHECK-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation +// FOLD-LABEL: @FuseFullyConnectedReshapeAddConstWithActivation +func @FuseFullyConnectedReshapeAddConstWithActivation(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> { + %cst = constant dense<3.0> : tensor<40x40xf32> + %cst2 = constant dense<2.0> : tensor<40xf32> + %shape1 = constant dense<[1, 40, 40]> : tensor<3xi32> + %shape2 = constant dense<[40, 40]> : tensor<2xi32> + + %0 = "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, tensor<40x40xf32>) -> (tensor<40x40xf32>) + %1 = "tfl.reshape"(%0, %shape1) : (tensor<40x40xf32>, tensor<3xi32>) -> tensor<1x40x40xf32> + %2 = "tfl.add"(%1, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x40x40xf32>, tensor<40xf32>) -> tensor<1x40x40xf32> + %3 = "tfl.reshape"(%2, %shape2) : (tensor<1x40x40xf32>, tensor<2xi32>) -> tensor<40x40xf32> + + return %3 : tensor<40x40xf32> + + // CHECK: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32> + // CHECK: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // CHECK: %[[rs1:.*]] = "tfl.reshape"(%[[fc]] + // CHECK: %[[rs2:.*]] = "tfl.reshape"(%[[rs1]] + // CHECK: return %[[rs2]] + + // FOLD: %[[cst:.*]] = constant dense<5.000000e+00> : tensor<40x40xf32> + // FOLD: %[[fc:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[cst]]) {fused_activation_function = "RELU6", keep_num_dims = false, weights_format = "DEFAULT"} + // FOLD: return %[[fc]] +} + // CHECK-LABEL: @NotReorderReshapeAddIfNotBroadcastableAfter func @NotReorderReshapeAddIfNotBroadcastableAfter(%arg0: tensor<40x10x4xf32>) -> tensor<40x40xf32> { %cst = constant dense<2.0> : tensor<40xf32> @@ -829,6 +855,15 @@ func @doNotConvertNonTrivialTransposeToReshape(%arg0: tensor<6x6x256x1xf32>) -> // CHECK: return %[[RESULT]] } +// CHECK-LABEL: Relu +func @Relu(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %cst = constant dense<0.0> : tensor + %0 = "tfl.maximum"(%arg0, %cst) : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + return %0 : tensor<2x3xf32> + + // CHECK: %[[RESULT:.*]] = "tfl.relu"(%arg0) + // CHECK: return %[[RESULT]] +} // CHECK-LABEL: Relu1 func @Relu1(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { @@ -992,3 +1027,91 @@ func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: return %arg0 } +func @squaredDifferenceReluRemoveRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tfl.relu"(%0) : (tensor<1xf32>) -> tensor<1xf32> + return %1: tensor<1xf32> + +// CHECK-LABEL: squaredDifferenceReluRemoveRelu +// CHECK: %[[RESULT:.*]] = tfl.squared_difference %arg0, %arg1 : tensor<1xf32> +// CHECK: return %[[RESULT]] +} + +func @ConvertSqueezeToReshapeWithDynamicDimension(%arg0: tensor) -> tensor { + %0 = "tfl.squeeze"(%arg0) {squeeze_dims = [1]}: (tensor) -> tensor + return %0: tensor + +// CHECK-LABEL: ConvertSqueezeToReshapeWithDynamicDimension +// CHECK: [[CONST:.*]] = constant dense<[-1, 8, 3]> : tensor<3xi32> +// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor, tensor<3xi32>) -> tensor +// CHECK: return 
%[[RESULT]] +} + +func @ConvertSqueezeToReshapeWithDynamicDimension2(%arg0: tensor) -> tensor<1x8x3xf32> { + %0 = "tfl.squeeze"(%arg0) {squeeze_dims = [0]}: (tensor) -> tensor<1x8x3xf32> + return %0: tensor<1x8x3xf32> + +// CHECK-LABEL: ConvertSqueezeToReshapeWithDynamicDimension2 +// CHECK: [[CONST:.*]] = constant dense<[1, 8, 3]> : tensor<3xi32> +// CHECK: %[[RESULT:.*]] = "tfl.reshape"(%arg0, %[[CONST:.*]]) : (tensor, tensor<3xi32>) -> tensor<1x8x3xf32> +// CHECK: return %[[RESULT]] +} + +func @DontConvertSqueezeToReshape(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tfl.squeeze"(%arg0) {squeeze_dims = [0]}: (tensor<*xf32>) -> tensor<*xf32> + return %0: tensor<*xf32> + +// CHECK-LABEL: DontConvertSqueezeToReshape +// CHECK: %[[RESULT:.*]] = "tfl.squeeze"(%arg0) +// CHECK: return %[[RESULT]] +} + +func @ConvertPow1ToIdentity(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<1.000000e+00> : tensor + %0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + +// CHECK-LABEL: ConvertPow1ToIdentity +// CHECK: return %arg0 +} + +func @ConvertPow2ToSquare(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = constant dense<2.000000e+00> : tensor + %0 = "tfl.pow"(%arg0, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + +// CHECK-LABEL: ConvertPow2ToSquare +// CHECK: %[[RESULT:.*]] = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2x2xf32> +// CHECK: return %[[RESULT]] +} + +func @ConvertIdentityGatherNdOp(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %0 = "tfl.gather_nd"(%arg0, %cst) : (tensor<4x3xf32>, tensor<4x1xi32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> + +// CHECK-LABEL: ConvertIdentityGatherNdOp +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3xf32> +} + +func @ConvertIdentityGatherNdOp3D(%arg0: tensor<4x3x4xf32>) -> tensor<4x3x4xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %0 = "tfl.gather_nd"(%arg0, %cst) : (tensor<4x3x4xf32>, tensor<4x1xi32>) -> tensor<4x3x4xf32> + return %0 : tensor<4x3x4xf32> + +// CHECK-LABEL: ConvertIdentityGatherNdOp3D +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3x4xf32>) -> tensor<4x3x4xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3x4xf32> +} + +func @ConvertIdentityScatterNd(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> { + %cst = constant dense<[[0], [1], [2], [3]]> : tensor<4x1xi32> + %shape = constant dense<[4, 3]> : tensor<2xi32> + %0 = "tfl.scatter_nd"(%cst, %arg0, %shape) : (tensor<4x1xi32>, tensor<4x3xf32>, tensor<2xi32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> + +// CHECK-LABEL: ConvertIdentityScatterNd +// CHECK-SAME: (%[[ARG:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> +// CHECK-NEXT: return %[[ARG]] : tensor<4x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir index 7ce60d98062..6847cdd5874 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-composite-functions-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: tf-opt -tfl-prepare-composite-funcs-tf %s -split-input-file -verify-diagnostics | FILECHECK_OPTS="" FileCheck %s module{ func @embedding(%arg0: tensor<*xf32>, %arg1: 
tensor<*xi32>) -> tensor<*xf32> attributes {tf._implements = "embedding_matmul", tf._reference = "mlir"} { @@ -154,18 +154,18 @@ func @layernormalizedlstmcellsimple(%arg0: tensor<1x?xf32>, %arg1: tensor<3x4xf3 // ----- module { -func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor - %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor - %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor<8x10xf32>) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor<8x10xf32>) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor + return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = 
["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -180,33 +180,33 @@ func @inference_standard_lstm_time_major(%arg0: tensor, %arg1: tensor // CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> 
tensor<8x10xf32> // CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } // ----- module { -func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { +func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32> - %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> - %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - return %5, %4, %5, %5, %6 : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor + return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = 
"lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { +// CHECK: func @inference_standard_lstm_non_time_major([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = false} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -221,15 +221,15 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te // CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK: [[VAL_21:%.*]] = constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 
: i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> // CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } @@ -237,18 +237,18 @@ func @inference_standard_lstm_non_time_major(%arg0: tensor<8x8x8xf32>, %arg1: te // ----- module { -func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor - %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor - %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor<8x10xf32>) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor<8x10xf32>) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor + return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, 
[[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major_go_backwards([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<0> : tensor<1xi32> // CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor, tensor<1xi32>) -> tensor // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -265,15 +265,15 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, // CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_21:%.*]] = constant unit -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // 
CHECK: [[VAL_23:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_25:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> // CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } @@ -281,18 +281,18 @@ func @inference_standard_lstm_time_major_go_backwards(%arg0: tensor, // ----- module { -func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor<8x8x8xf32>, tensor<8x40xf32>) -> tensor<8x8x40xf32> %1 = "tf.Add"(%0, %arg5) : (tensor<8x8x40xf32>, tensor<40xf32>) -> tensor<8x8x40xf32> %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor<8x8x40xf32>, tensor<10x40xf32>) -> tensor<8x8x10xf32> - %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> - %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor) -> tensor<8x8x10xf32> - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %4 = "tf.Add"(%2, %arg2) : (tensor<8x8x10xf32>, tensor<8x10xf32>) -> tensor<8x8x10xf32> + %5 = "tf.Add"(%arg1, 
%arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - return %5, %4, %5, %5, %6 : tensor, tensor<8x8x10xf32>, tensor, tensor, tensor + return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { +// CHECK: func @inference_standard_lstm_non_time_major_go_backwards([[VAL_0:%.*]]: tensor<8x8x8xf32>, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = true, tf.time_major = false} { // CHECK: [[VAL_6:%.*]] = constant dense<1> : tensor<1xi32> // CHECK: [[VAL_7:%.*]] = "tf.ReverseV2"([[VAL_0]], [[VAL_6]]) : (tensor<8x8x8xf32>, tensor<1xi32>) -> tensor<8x8x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -309,15 +309,15 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3 // CHECK: [[VAL_19:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_20:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_18]], [[VAL_19]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_21:%.*]] = constant unit -// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, [[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor<8x8x10xf32> +// CHECK: [[VAL_22:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_7]], [[VAL_14]]#0, [[VAL_14]]#1, [[VAL_14]]#2, [[VAL_14]]#3, [[VAL_17]]#0, [[VAL_17]]#1, 
[[VAL_17]]#2, [[VAL_17]]#3, [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_20]]#0, [[VAL_20]]#1, [[VAL_20]]#2, [[VAL_20]]#3, [[VAL_21]], [[VAL_21]], [[VAL_1]], [[VAL_2]], [[VAL_21]], [[VAL_21]], [[VAL_21]], [[VAL_21]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false} : (tensor<8x8x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor<8x8x10xf32> // CHECK: [[VAL_23:%.*]] = constant dense<[0, -1, 0]> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_25:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_26:%.*]] = "tf.StridedSlice"([[VAL_22]], [[VAL_23]], [[VAL_24]], [[VAL_25]]) {begin_mask = 5 : i64, ellipsis_mask = 0 : i64, end_mask = 5 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 2 : i64} : (tensor<8x8x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_28:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> // CHECK: [[VAL_29:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor, tensor, tensor +// CHECK: return [[VAL_26]], [[VAL_22]], [[VAL_27]], [[VAL_28]], [[VAL_29]] : tensor<8x10xf32>, tensor<8x8x10xf32>, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } @@ -325,25 +325,25 @@ func @inference_standard_lstm_non_time_major_go_backwards(%arg0: tensor<8x8x8xf3 // ----- module { -func @inference_can_fuse(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { +func @inference_can_fuse(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor} : () -> tensor - %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor, tensor, tensor, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) + %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", 
"tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse} : (tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) %2 = "tf.Add"(%0, %1#1) : (tensor, tensor) -> tensor return } -func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor - %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor - %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor<8x10xf32>) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor<8x10xf32>) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor + return %5, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", 
"tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major_can_fuse([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -358,15 +358,15 @@ func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg // CHECK: [[VAL_17:%.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = 
"tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> // CHECK: [[VAL_27:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } @@ -374,26 +374,26 @@ func @inference_standard_lstm_time_major_can_fuse(%arg0: tensor, %arg // ----- module { -func @inference_can_fuse_last_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { +func @inference_can_fuse_last_output(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) { %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = f32, value = dense<0.000000e+00> : tensor} : () -> tensor - %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse_last_output} : (tensor, tensor, tensor, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) + %1:5 = "tf.PartitionedCall"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {Tin = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT", "tfdtype$DT_FLOAT"], _output_shapes = ["tfshape$dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 9 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$"], _read_only_resource_inputs = [], config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\008\01", device = "", executor_type = "", f = @inference_standard_lstm_time_major_can_fuse_last_output} : (tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor<8x40xf32>, tensor<10x40xf32>, tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) %2 = "tf.Add"(%0, %1#0) : (tensor, tensor<8x10xf32>) -> tensor<8x10xf32> return } -func 
@inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor, %arg1: tensor<8x10xf32>, %arg2: tensor<8x10xf32>, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor - %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor - %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor - %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor<8x10xf32>) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor<8x10xf32>) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor - %7 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor<8x10xf32> - return %7, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor, tensor, tensor + %7 = "tf.Add"(%arg1, %arg2) : (tensor<8x10xf32>, tensor<8x10xf32>) -> tensor<8x10xf32> + return %7, %4, %5, %5, %6 : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor } -// CHECK: func @inference_standard_lstm_time_major_can_fuse_last_output([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor, [[VAL_2:%.*]]: tensor, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: func @inference_standard_lstm_time_major_can_fuse_last_output([[VAL_0:%.*]]: tensor, [[VAL_1:%.*]]: tensor<8x10xf32>, [[VAL_2:%.*]]: tensor<8x10xf32>, [[VAL_3:%.*]]: tensor<8x40xf32>, [[VAL_4:%.*]]: tensor<10x40xf32>, [[VAL_5:%.*]]: tensor<40xf32>) -> (tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor) attributes 
{tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$dim { size: 8 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: false", "tfshape$unknown_rank: false"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { // CHECK: [[VAL_6:%.*]] = constant dense<[1, 0]> : tensor<2xi32> // CHECK: [[VAL_7:%.*]] = "tf.Transpose"([[VAL_3]], [[VAL_6]]) : (tensor<8x40xf32>, tensor<2xi32>) -> tensor<40x8xf32> // CHECK: [[VAL_8:%.*]] = constant dense<[1, 0]> : tensor<2xi32> @@ -408,15 +408,15 @@ func @inference_standard_lstm_time_major_can_fuse_last_output(%arg0: tensor : tensor} : () -> tensor // CHECK: [[VAL_18:%.*]]:4 = "tf.SplitV"([[VAL_5]], [[VAL_16]], [[VAL_17]]) : (tensor<40xf32>, tensor<4xi32>, tensor) -> (tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>) // CHECK: [[VAL_19:%.*]] = constant unit -// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor, tensor, none, none, none, none) -> tensor +// CHECK: [[VAL_20:%.*]] = "tfl.unidirectional_sequence_lstm"([[VAL_0]], [[VAL_12]]#0, [[VAL_12]]#1, [[VAL_12]]#2, [[VAL_12]]#3, [[VAL_15]]#0, [[VAL_15]]#1, [[VAL_15]]#2, [[VAL_15]]#3, [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_18]]#0, [[VAL_18]]#1, [[VAL_18]]#2, [[VAL_18]]#3, [[VAL_19]], [[VAL_19]], [[VAL_1]], [[VAL_2]], [[VAL_19]], [[VAL_19]], [[VAL_19]], [[VAL_19]]) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = true} : (tensor, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x8xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>, none, none, none, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, tensor<10xf32>, none, none, tensor<8x10xf32>, tensor<8x10xf32>, none, none, none, none) -> tensor // CHECK: [[VAL_21:%.*]] = constant dense<[-1, 0, 0]> : tensor<3xi32> // CHECK: [[VAL_22:%.*]] = constant dense<0> : tensor<3xi32> // CHECK: [[VAL_23:%.*]] = constant dense<1> : tensor<3xi32> // CHECK: [[VAL_24:%.*]] = "tf.StridedSlice"([[VAL_20]], [[VAL_21]], [[VAL_22]], [[VAL_23]]) {begin_mask = 6 : i64, ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x10xf32> -// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor +// CHECK: [[VAL_25:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> +// CHECK: [[VAL_26:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor<8x10xf32> // CHECK: [[VAL_27:%.*]] = 
"tf.Const"() {value = dense<0.000000e+00> : tensor<1xf32>} : () -> tensor -// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor, tensor, tensor +// CHECK: return [[VAL_24]], [[VAL_20]], [[VAL_25]], [[VAL_26]], [[VAL_27]] : tensor<8x10xf32>, tensor, tensor<8x10xf32>, tensor<8x10xf32>, tensor // CHECK: } } @@ -456,6 +456,32 @@ func @inference_standard_lstm_time_major_cannot_fuse(%arg0: tensor, % // ----- +module { +func @dynamic_shape_non_fuse_standard_lstm(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<8x40xf32>, %arg4: tensor<10x40xf32>, %arg5: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { + %0 = "tf.BatchMatMulV2"(%arg0, %arg3) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor + %1 = "tf.Add"(%0, %arg5) : (tensor, tensor<40xf32>) -> tensor + %2 = "tf.BatchMatMulV2"(%1, %arg4) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor + %3 = "tf.Add"(%2, %arg1) : (tensor, tensor) -> tensor + %4 = "tf.Add"(%2, %arg2) : (tensor, tensor) -> tensor + %5 = "tf.Add"(%arg1, %arg2) : (tensor, tensor) -> tensor + %6 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor + return %5, %4, %5, %5, %6 : tensor, tensor, tensor, tensor, tensor +} + +// CHECK: func @dynamic_shape_non_fuse_standard_lstm(%[[VAL_0:.*]]: tensor, %[[VAL_1:.*]]: tensor, %[[VAL_2:.*]]: tensor, %[[VAL_3:.*]]: tensor<8x40xf32>, %[[VAL_4:.*]]: tensor<10x40xf32>, %[[VAL_5:.*]]: tensor<40xf32>) -> (tensor, tensor, tensor, tensor, tensor) attributes {tf._input_shapes = ["tfshape$dim { size: -1 } dim { size: 8 } dim { size: 8 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$dim { size: -1 } dim { size: 10 }", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true", "tfshape$unknown_rank: true"], tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c", tf.api_preferred_device = "CPU", tf.go_backwards = false, tf.time_major = true} { +// CHECK: %[[VAL_6:.*]] = "tf.BatchMatMulV2"(%[[VAL_0]], %[[VAL_3]]) {adj_x = false, adj_y = false} : (tensor, tensor<8x40xf32>) -> tensor +// CHECK: %[[VAL_7:.*]] = "tf.Add"(%[[VAL_6]], %[[VAL_5]]) : (tensor, tensor<40xf32>) -> tensor +// CHECK: %[[VAL_8:.*]] = "tf.BatchMatMulV2"(%[[VAL_7]], %[[VAL_4]]) {adj_x = false, adj_y = true} : (tensor, tensor<10x40xf32>) -> tensor +// CHECK: %[[VAL_9:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_10:.*]] = "tf.Add"(%[[VAL_8]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = "tf.Add"(%[[VAL_1]], %[[VAL_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[VAL_12:.*]] = "tf.Const"() {_output_shapes = ["tfshape$"], device = "/device:CPU:0", dtype = f32, value = dense<1.000000e+00> : tensor} : () -> tensor +// CHECK: return %[[VAL_11]], %[[VAL_10]], %[[VAL_11]], %[[VAL_11]], %[[VAL_12]] : tensor, tensor, tensor, tensor, tensor +// CHECK: } +} + +// ----- + module { func @nms_padded(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: 
tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tensor<1x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} { %0 = "tf.Const"() {value = dense<1> : tensor<1x10xi32>} : () -> tensor<1x10xi32> @@ -481,3 +507,15 @@ func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf3 // expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}} func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor) -> (tensor<2x10xi32>, tensor) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"} } + +// ----- + +module { +// CHECK-LABEL: func @some_func +// CHECK-LABEL: func @func_with_call +func @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"} +func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> { + %0 = call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32> + return %0 : tensor<100xf32> + } +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 7b51ec32f89..066139e179b 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -1,5 +1,7 @@ // RUN: tf-opt -tfl-prepare-tf %s | FileCheck %s +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func @conv(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<256x3x32x32xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x16x30x30xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) { ^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<256x3x32x32xf32>) : // OK @@ -578,3 +580,19 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> // CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32> // CHECK: return %[[RES]] } + +// CHECK-LABEL: xla_conv +func @xla_conv(%arg0: tensor<4x8x8x16xf32>) -> tensor<4x8x8x16xf32> { + %0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<3x3x16x16xf32>} : () -> tensor<3x3x16x16xf32> loc("Const_1") + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor loc("XlaConv/feature_group_count") + %2 = "tf.Const"() {value = dense<1> : tensor<2x2xi32>} : () -> tensor<2x2xi32> loc("XlaConv/padding") + %3 = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> tensor<2xi32> loc("XlaConv/window_strides") + %4 = "tf.XlaConv"(%arg0, %0, %3, %2, %3, %3, %1) {device = "", dimension_numbers = "\18\02 \032\02\00\01@\03P\03Z\02\01\02b\02\01\02", precision_config = ""} : (tensor<4x8x8x16xf32>, tensor<3x3x16x16xf32>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<4x8x8x16xf32> + return %4 : tensor<4x8x8x16xf32> + // CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<16xf32> + // CHECK: %[[CST0:.*]] = constant dense<1.000000e+00> : tensor<16x3x3x16xf32> + // CHECK: %[[RES:.*]] = "tfl.conv_2d"(%arg0, %[[CST0]], %[[CST]]) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<4x8x8x16xf32>, tensor<16x3x3x16xf32>, tensor<16xf32>) -> tensor<4x8x8x16xf32> + // CHECK: return %[[RES]] +} + +} diff --git 
a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir new file mode 100644 index 00000000000..1bac8019a30 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir @@ -0,0 +1,20 @@ +// RUN: tf-opt -tfl-raise-custom-ops -canonicalize %s -o - | FileCheck %s + +// CHECK-LABEL: custom_op +func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // will be preserved since it has uses. + %2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // will be removed since it doesn't have uses and doesn't have side effect. + "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %2 : tensor<4xf32> + +// CHECK-NEXT: %[[CST:.*]] = constant dense<1.000000e+00> +// CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32> +// CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ( { +// CHECK-NEXT: %[[MY_CUSTOM:.*]] = "tf.MyCustomOp"(%[[MUL]], %[[CST]]) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: "tfl.yield"(%[[MY_CUSTOM]]) : (tensor<4xf32>) -> () +// CHECK-NEXT: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +// CHECK-NEXT: return %[[CUSTOM]] : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 1e1c431822d..d63eb481376 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -166,6 +166,10 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // The below passes only make sense if Builtin TFLite ops are enabled // for emission. if (pass_config.emit_builtin_tflite_ops) { + // Run shape inference after variables are converted to constants. + if (pass_config.shape_inference) { + pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + } // Prepare for TFLite dialect, rerun canonicalization, and then legalize to // the TFLite dialect. pass_manager->addPass( @@ -173,8 +177,19 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); if (pass_config.shape_inference) { // Add a shape inference pass to optimize away the unnecessary casts. + // This also fixes the unranked shapes due to TF ops constant folding. + // TODO(fengliuai): remove this pass if TableGen patterns have a better way + // to control the shapes for the intermediate results. pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } + + // Inline function calls that are left in the graph after folding functional + // control flow ops (IfOp, CaseOp). + pass_manager->addPass(mlir::createInlinerPass()); + + // This pass removes the asset file dependencies in hash table use cases.
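+  // A note on ordering, kept as a rough sketch of the intent rather than a
+  // guarantee: shape inference runs once before the TFLite prepare pass so
+  // that shapes refined after variables become constants are available, and
+  // once more after canonicalization so that the casts and unranked results
+  // left behind by TF constant folding get cleaned up. The inliner has to
+  // come after the functional control flow ops (IfOp, CaseOp) are folded,
+  // otherwise their callees are still referenced and cannot be merged into
+  // the caller. A minimal standalone pipeline with the same shape, assuming
+  // only the passes named above, would look roughly like:
+  //
+  //   mlir::PassManager pm(module.getContext());
+  //   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+  //   pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+  //   pm.addPass(mlir::TF::CreateTFShapeInferencePass());
+  //   pm.addPass(mlir::createInlinerPass());
+  //   (void)pm.run(module);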
+ pass_manager->addPass(mlir::TF::CreateInitTextFileToImportPass()); + pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification)); pass_manager->addPass(mlir::TFL::CreateOptimizePass()); @@ -182,6 +197,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, // so that it can target constants introduced once TensorFlow Identity ops // are removed during legalization. pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); + pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass()); pass_manager->addPass(mlir::createSymbolDCEPass()); pass_manager->addNestedPass(mlir::createCanonicalizerPass()); pass_manager->addNestedPass(mlir::createCSEPass()); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 963ab743a83..046c7bbbcf0 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -144,6 +144,10 @@ int main(int argc, char **argv) { StatusOr module; + tensorflow::GraphImportConfig specs; + specs.upgrade_legacy = upgrade_legacy; + specs.prune_unused_nodes = true; + // TODO(b/147435528): We need to test the e2e behavior once the graph freezing // inside mlir is done. if (import_saved_model_object_graph || import_saved_model_signature_defs) { @@ -168,12 +172,10 @@ int main(int argc, char **argv) { return kTrFailure; } - module = tensorflow::ImportSavedModel(input_file_name, saved_model_version, - tags, exported_names, &context); + module = + tensorflow::ImportSavedModel(input_file_name, saved_model_version, tags, + exported_names, specs, &context); } else { - tensorflow::GraphImportConfig specs; - specs.upgrade_legacy = upgrade_legacy; - specs.prune_unused_nodes = true; module = tensorflow::LoadFromGraphdefOrMlirSource( input_file_name, input_mlir, use_splatted_constant, custom_opdefs, specs, debug_info_file, input_arrays, input_dtypes, input_shapes, diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 714bc493bed..c158f3a8e21 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -129,6 +129,18 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( bool emit_select_tf_ops, bool emit_custom_ops, const mlir::TFL::QuantizationSpecs& quant_specs, std::string* result, mlir::PassManager* pass_manager) { + // Register a warning handler only log to std out. 
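+  // Returning mlir::failure() below marks the diagnostic as unhandled, so
+  // warnings are only echoed here (to stdout and the warning log) and every
+  // diagnostic still reaches the other registered handlers, including the
+  // StatusScopedDiagnosticHandler set up right after this one.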
+ mlir::ScopedDiagnosticHandler s( + module.getContext(), [](mlir::Diagnostic& diag) { + if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) { + for (auto& note : diag.getNotes()) { + std::cout << note.str() << "\n"; + LOG(WARNING) << note.str() << "\n"; + } + } + return mlir::failure(); + }); + mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), /*propagate=*/true); @@ -186,7 +198,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer( StatusOr ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, - absl::Span exported_names, mlir::MLIRContext* context) { + absl::Span exported_names, const GraphImportConfig& specs, + mlir::MLIRContext* context) { if (saved_model_version == 2) { auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, context); @@ -194,7 +207,7 @@ StatusOr ImportSavedModel( return module_or.ConsumeValueOrDie(); } else if (saved_model_version == 1) { auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, exported_names, context); + input_filename, tags, exported_names, context, specs.upgrade_legacy); if (!module_or.status().ok()) return module_or.status(); return module_or.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 4ad58c4f8ef..8f1edec8879 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -48,7 +48,8 @@ LoadFromGraphdefOrMlirSource( stream_executor::port::StatusOr ImportSavedModel( const std::string& input_filename, const int saved_model_version, const std::unordered_set& tags, - absl::Span exported_names, mlir::MLIRContext* context); + absl::Span exported_names, const GraphImportConfig& specs, + mlir::MLIRContext* context); // Taking a MLIR module in TF executor dialect and a set of parameters, // applies a set of passes to convert the module to TF Lite dialect and diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index f5ef2585be5..50a7ee52430 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -16,6 +16,7 @@ limitations under the License. // This transformation pass convert dense tensor to sparse format. #include "absl/memory/memory.h" +#include "third_party/eigen3/Eigen/Core" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -36,6 +37,16 @@ const float kMinSparsityLevel = 0.3; // Heuristic to check if a block configuration is correct. 
const float kBlockOverRandomSparsityRatio = 0.9; +Eigen::half APFloatToEigenHalf(const APFloat& val) { + uint16_t raw_data = val.bitcastToAPInt().getZExtValue(); + return Eigen::half_impl::raw_uint16_to_half(raw_data); +} + +APFloat EigenHalfToAPFloat(const Eigen::half& val) { + uint16_t raw_data = val.x; + return APFloat(APFloat::IEEEhalf(), APInt(16, raw_data)); +} + void PopulateEncodingParams(const std::vector& block_size, std::vector* traversal_order, std::vector* format, @@ -64,14 +75,18 @@ void PopulateEncodingParams(const std::vector& block_size, } } +inline float GetSparsity(const int num_zeros, const int num_elements) { + return (1.0 * num_zeros / num_elements); +} + float CalculateRandomSparsity(const ElementsAttr& attr, const ShapedType& type) { int num_elements = type.getNumElements(); int num_zeros = 0; - if (type.getElementType().isF32()) { - for (const auto val : attr.getValues()) { - if (val == 0.f) { + if (type.getElementType().isa()) { + for (const auto val : attr.getValues()) { + if (val.isZero()) { num_zeros++; } } @@ -83,7 +98,7 @@ float CalculateRandomSparsity(const ElementsAttr& attr, } } - return 1.0 * num_zeros / num_elements; + return GetSparsity(num_zeros, num_elements); } float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, @@ -108,7 +123,19 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, for (const auto val : attr.getValues()) data.push_back(val); format_converter.DenseToSparse(data.data()); sparsity = - 1 - 1.0 * format_converter.GetData().size() / type.getNumElements(); + GetSparsity(type.getNumElements() - format_converter.GetData().size(), + type.getNumElements()); + } else if (type.getElementType().isF16()) { + tflite::optimize::sparsity::FormatConverter format_converter( + shape, traversal_order, format, b_size, b_map); + std::vector data; + data.reserve(type.getNumElements()); + for (const auto& val : attr.getValues()) + data.push_back(APFloatToEigenHalf(val)); + format_converter.DenseToSparse(data.data()); + sparsity = + GetSparsity(type.getNumElements() - format_converter.GetData().size(), + type.getNumElements()); } else if (type.getElementType().isa()) { tflite::optimize::sparsity::FormatConverter format_converter( shape, traversal_order, format, b_size, b_map); @@ -117,7 +144,8 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, for (const auto val : attr.getValues()) data.push_back(val); format_converter.DenseToSparse(data.data()); sparsity = - 1 - 1.0 * format_converter.GetData().size() / type.getNumElements(); + GetSparsity(type.getNumElements() - format_converter.GetData().size(), + type.getNumElements()); } return sparsity; @@ -184,8 +212,8 @@ InspectResult InspectWeight( template std::vector BuildSparsityParameterAttribute( - const std::vector& block_size, Operation* inst, OpBuilder* builder, - SparsityParameterAttr* s_param) { + const std::vector& block_size, const T* dense_buffer, Operation* inst, + OpBuilder* builder, SparsityParameterAttr* s_param) { ElementsAttr attr; ShapedType type; if (auto cst = dyn_cast(inst)) { @@ -210,10 +238,7 @@ std::vector BuildSparsityParameterAttribute( tflite::optimize::sparsity::FormatConverter format_converter( shape, traversal_order, format, b_size, b_map); - std::vector data; - data.reserve(type.getNumElements()); - for (const auto val : attr.getValues()) data.push_back(val); - format_converter.DenseToSparse(data.data()); + format_converter.DenseToSparse(dense_buffer); auto metadata = 
format_converter.GetDimMetadata(); auto compressed_data = format_converter.GetData(); const int dim_size = metadata.size() / 2; @@ -264,15 +289,28 @@ void DenseToSparse::runOnFunction() { func.walk([&](SparseOpInterface sparse_op) { const auto& sparse_operands = sparse_op.GetSparseOperands(); std::vector> supported_block_size; - for (const int operand : sparse_operands) { + for (int operand : sparse_operands) { auto* op = sparse_op.getOperation(); - const auto& value = op->getOperand(operand); + auto value = op->getOperand(operand); auto* inst = value.getDefiningOp(); if (!inst) { continue; } + // There could be a Dequantize op after the weight tensor in cases like + // fp16 post-training quantization. We need to get the weight from the + // input of the Dequantize op. + if (isa(inst)) { + op = inst; + value = inst->getOperand(0); + inst = value.getDefiningOp(); + if (!inst) { + continue; + } + operand = 0; + } + ShapedType type; if (isa(inst)) { supported_block_size = sparse_op.GetFloatBlockSize(); @@ -297,22 +335,60 @@ void DenseToSparse::runOnFunction() { builder.setInsertionPoint(op); SparsityParameterAttr s_param; if (auto cst = dyn_cast(inst)) { - std::vector compressed_data = - BuildSparsityParameterAttribute(result.selected_block_size, - inst, &builder, &s_param); - auto compressed_data_type = RankedTensorType::get( - {static_cast(compressed_data.size())}, - builder.getF32Type()); - auto new_value = DenseElementsAttr::get(compressed_data_type, - compressed_data); - auto s_const = builder.create(op->getLoc(), cst.value(), - s_param, new_value); - value.replaceAllUsesWith(s_const.getResult()); - cst.erase(); + auto attr = cst.value(); + auto type = cst.getType().cast(); + if (type.getElementType().isF32()) { + std::vector dense_data; + dense_data.reserve(type.getNumElements()); + for (const auto val : attr.getValues()) + dense_data.push_back(val); + std::vector compressed_data = + BuildSparsityParameterAttribute(result.selected_block_size, + dense_data.data(), inst, + &builder, &s_param); + auto compressed_data_type = RankedTensorType::get( + {static_cast(compressed_data.size())}, + builder.getF32Type()); + auto new_value = DenseElementsAttr::get(compressed_data_type, + compressed_data); + auto s_const = builder.create( + op->getLoc(), cst.value(), s_param, new_value); + value.replaceAllUsesWith(s_const.getResult()); + cst.erase(); + } else if (type.getElementType().isF16()) { + std::vector dense_data; + dense_data.reserve(type.getNumElements()); + for (const auto& val : attr.getValues()) + dense_data.push_back(APFloatToEigenHalf(val)); + std::vector compressed_data = + BuildSparsityParameterAttribute( + result.selected_block_size, dense_data.data(), inst, &builder, + &s_param); + std::vector apfloat_data; + apfloat_data.reserve(type.getNumElements()); + for (const auto& val : compressed_data) + apfloat_data.push_back(EigenHalfToAPFloat(val)); + auto compressed_data_type = RankedTensorType::get( + {static_cast(compressed_data.size())}, + type.getElementType()); + auto new_value = + DenseElementsAttr::get(compressed_data_type, apfloat_data); + auto s_const = builder.create( + op->getLoc(), cst.value(), s_param, new_value); + value.replaceAllUsesWith(s_const.getResult()); + cst.erase(); + } } else if (auto cst = dyn_cast(inst)) { + auto attr = cst.value(); + auto type = cst.getType().cast(); + std::vector dense_data(type.getNumElements()); + dense_data.reserve(type.getNumElements()); + for (const auto& val : attr.getValues()) + dense_data.push_back(val); std::vector compressed_data = 
BuildSparsityParameterAttribute(result.selected_block_size, - inst, &builder, &s_param); + dense_data.data(), inst, + &builder, &s_param); auto compressed_data_type = RankedTensorType::get( {static_cast(compressed_data.size())}, builder.getIntegerType(8, true)); diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index b745be7753a..2054bab4185 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -276,7 +276,7 @@ ConvertTFDilatedConvOp::ExtractDilationsAttrFromBlockShape( } // Check that the block_shape of `stb_op` and `bts_op` are equal. if (stb_bs_attr.getNumElements() != bts_bs_attr.getNumElements()) return {}; - for (uint64_t i = 0; i < stb_bs_attr.getNumElements(); ++i) { + for (uint64_t i = 0, end = stb_bs_attr.getNumElements(); i < end; ++i) { if (stb_bs_attr.getValue({i}) != bts_bs_attr.getValue({i})) return {}; } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index f5b45df3eee..47cfaecd3fb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -66,8 +66,10 @@ def LegalizeTFConstToTFLConst: Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; // Convert to std constant for statically shaped, non-opaque constants. -def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), - [(AnyStaticShapeTensor $res)], (addBenefit 10)>; +def ConvertTfConstToStdConst : Pat< + (TF_ConstOp:$res NonOpaqueElementsAttr:$value), + (ConstantOp $value), + [(AnyStaticShapeTensor $res)], (addBenefit 10)>; //===----------------------------------------------------------------------===// // Unary ops patterns. 
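// Note on the renames in this file: giving names to previously anonymous
// `def : Pat<...>` records does not change the generated rewrite logic. The
// record name becomes the identifier of the pattern struct that TableGen
// emits into the generated rewriters, which makes debug output and pattern
// coverage reports easier to read than the auto-generated anonymous names.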
@@ -162,186 +164,234 @@ def LegalizeMaximum : Pat<(TF_MaximumOp $arg1, $arg2), def LegalizeMinimum : Pat<(TF_MinimumOp $arg1, $arg2), (TFL_MinimumOp $arg1, $arg2)>; -def : Pat<(TF_NegOp $arg), (TFL_NegOp $arg)>; -def : Pat<(TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis), - (TFL_OneHotOp $indices, $depth, $on_value, $off_value, - (convertIntAttrTo32Bit $axis))>; -def : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>; -def : Pat<(TF_RangeOp $start, $limit, $delta), (TFL_RangeOp $start, $limit, $delta)>; -def : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>; -def : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>; -def : Pat<(TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim), - (TFL_ReverseSequenceOp $input, $seq_lengths, - (convertIntAttrTo32Bit $seq_dim), - (convertIntAttrTo32Bit $batch_dim))>; -def : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>; -def : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>; -def : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>; -def : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>; -def : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), (TFL_SegmentSumOp $data, $segment_ids)>; -def : Pat<(TF_SelectOp $cond, $x, $y), (TFL_SelectOp $cond, $x, $y)>; -def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectOp $cond, $x, $y), [(HasSameStaticShapes $src_op)]>; -def : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), (TFL_SelectV2Op $cond, $x, $y), [(HasNotSameStaticShapes $src_op)]>; -def : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; -def : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; -def : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; -def : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>; -def : Pat<(TF_SoftmaxOp $arg), (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>; -def : Pat<(TF_SoftplusOp F32Tensor:$arg0), (TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0), (ConstantOp ConstantAttr, "1.0f">), TFL_AF_None))>; -def : Pat<(TF_SqueezeOp $arg, $squeeze_dims), (TFL_SqueezeOp $arg, $squeeze_dims)>; -def : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; -def : Pat<(TF_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, $perm)>; -def : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; -def : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; +def LegalizeNeg : Pat<(TF_NegOp $arg), (TFL_NegOp $arg)>; +def LegalizeOneHot : Pat< + (TF_OneHotOp $indices, $depth, $on_value, $off_value, $axis), + (TFL_OneHotOp $indices, $depth, $on_value, $off_value, + (convertIntAttrTo32Bit $axis))>; +def LegalizePow : Pat<(TF_PowOp $x, $y), (TFL_PowOp $x, $y)>; +def LegalizeRange : Pat<(TF_RangeOp $start, $limit, $delta), + (TFL_RangeOp $start, $limit, $delta)>; +def LegalizeRelu6 : Pat<(TF_Relu6Op $arg), (TFL_Relu6Op $arg)>; +def LegalizeRelu : Pat<(TF_ReluOp $arg), (TFL_ReluOp $arg)>; +def LegalizeReverseSequence : Pat< + (TF_ReverseSequenceOp $input, $seq_lengths, $seq_dim, $batch_dim), + (TFL_ReverseSequenceOp $input, $seq_lengths, + (convertIntAttrTo32Bit $seq_dim), (convertIntAttrTo32Bit $batch_dim))>; +def LegalizeRound : Pat<(TF_RoundOp $arg), (TFL_RoundOp $arg)>; +def LegalizeRsqrt : Pat<(TF_RsqrtOp $arg), (TFL_RsqrtOp $arg)>; +def LegalizeSqrt : Pat<(TF_SqrtOp $arg), (TFL_SqrtOp $arg)>; +def LegalizeSquare : Pat<(TF_SquareOp $arg), (TFL_SquareOp $arg)>; +def LegalizeSegmentSum : Pat<(TF_SegmentSumOp $data, I32Tensor:$segment_ids), + (TFL_SegmentSumOp $data, $segment_ids)>; +def LegalizeSelect : Pat<(TF_SelectOp $cond, $x, $y), + (TFL_SelectOp $cond, $x, $y)>; +def LegalizeSelectV2SameStaticShape : Pat<(TF_SelectV2Op:$src_op $cond, $x, $y), + (TFL_SelectOp $cond, $x, $y), + 
[(HasSameStaticShapes $src_op)]>; +def LegalizeSelectV2NotSameStaticShape : Pat< + (TF_SelectV2Op:$src_op $cond, $x, $y), + (TFL_SelectV2Op $cond, $x, $y), + [(HasNotSameStaticShapes $src_op)]>; +def LegalizeShape : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; +def LegalizeSigmoid : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; +def LegalizeSin : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; +def LegalizeSlice : Pat<(TF_SliceOp $input, $begin, $size), + (TFL_SliceOp $input, $begin, $size)>; +def LegalizeSoftmax : Pat<(TF_SoftmaxOp $arg), + (TFL_SoftmaxOp $arg, ConstF32Attr<"1.0">)>; +def LegalizeSoftPlus : Pat<(TF_SoftplusOp F32Tensor:$arg0), + (TFL_LogOp (TFL_AddOp (TFL_ExpOp $arg0), + (ConstantOp ConstantAttr, "1.0f">), + TFL_AF_None))>; +def LegalizeSqueeze : Pat<(TF_SqueezeOp $arg, $squeeze_dims), + (TFL_SqueezeOp $arg, $squeeze_dims)>; +def LegalizeTanh : Pat<(TF_TanhOp $arg), (TFL_TanhOp $arg)>; +def LegalizeTranspose : Pat<(TF_TransposeOp $arg, $perm), + (TFL_TransposeOp $arg, $perm)>; +def LegalizeWhere : Pat<(TF_WhereOp $arg), (TFL_WhereOp $arg)>; +def LegalizeZerosLike : Pat<(TF_ZerosLikeOp $arg), (TFL_ZerosLikeOp $arg)>; //===----------------------------------------------------------------------===// // Binary ops patterns. //===----------------------------------------------------------------------===// -def : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>; -def : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>; +def LegalizeLess : Pat<(TF_LessOp $l, $r), (TFL_LessOp $l, $r)>; +def LegalizeGreater : Pat<(TF_GreaterOp $l, $r), (TFL_GreaterOp $l, $r)>; -def : Pat<(TF_LessEqualOp $l, $r), (TFL_LessEqualOp $l, $r)>; -def : Pat<(TF_GreaterEqualOp $l, $r), (TFL_GreaterEqualOp $l, $r)>; +def LegalizeLessEqual : Pat<(TF_LessEqualOp $l, $r), (TFL_LessEqualOp $l, $r)>; +def LegalizeGreaterEqual : Pat<(TF_GreaterEqualOp $l, $r), + (TFL_GreaterEqualOp $l, $r)>; // Gather in TF -> Gather in TFL with axis=0 // The 'validate_indices' attribute is deprecated. 
-def : Pat<(TF_GatherOp $params, $indices, $ignored_validate_indices), - (TFL_GatherOp $params, $indices, ConstantAttr)>; +def LegalizeGather: Pat< + (TF_GatherOp $params, $indices, $ignored_validate_indices), + (TFL_GatherOp $params, $indices, ConstantAttr)>; -def : Pat<(TF_GatherNdOp $params, $indices), - (TFL_GatherNdOp $params, $indices)>; +def LegalizeGatherNd : Pat<(TF_GatherNdOp $params, $indices), + (TFL_GatherNdOp $params, $indices)>; -def : Pat<(TF_GatherV2Op $params, $indices, - (ConstantOp ElementsAttr:$axis), - ConstantAttr:$batch_dims), - (TFL_GatherOp $params, $indices, - ExtractSingleElementAsInt32:$axis)>; +def LegalizeGatherV2 : Pat< + (TF_GatherV2Op $params, $indices, (ConstantOp ElementsAttr:$axis), + ConstantAttr:$batch_dims), + (TFL_GatherOp $params, $indices, ExtractSingleElementAsInt32:$axis)>; -def : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>; +def LegalizeFloorDiv : Pat<(TF_FloorDivOp $l, $r), (TFL_FloorDivOp $l, $r)>; -def : Pat<(TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue), - (TFL_NotEqualOp $l, $r)>; +def LegalizeNotEqual : Pat< + (TF_NotEqualOp $l, $r, /*incompatible_shape_error=*/ConstBoolAttrTrue), + (TFL_NotEqualOp $l, $r)>; -def : Pat<(TF_LogicalAndOp $l, $r), (TFL_LogicalAndOp $l, $r)>; +def LegalizeLogicalAnd : Pat<(TF_LogicalAndOp $l, $r), + (TFL_LogicalAndOp $l, $r)>; -def : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>; +def LegalizeLogicalOr : Pat<(TF_LogicalOrOp $l, $r), (TFL_LogicalOrOp $l, $r)>; + +def LegalizeAdd : Pat<(TF_AddOp $lhs, $rhs), + (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeAddv2 : Pat<(TF_AddV2Op $lhs, $rhs), + (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeBiasAdd : Pat< + (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, IsDataFormatNHWC:$data_format), + (TFL_AddOp $l, $r, TFL_AF_None)>; +def LegalizeSub : Pat<(TF_SubOp $lhs, $rhs), + (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeMul : Pat<(TF_MulOp $lhs, $rhs), + (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeRealDiv : Pat<(TF_RealDivOp $lhs, $rhs), + (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeDiv : Pat<(TF_DivOp $lhs, $rhs), + (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; -def : Pat<(TF_AddOp $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; -def : Pat<(TF_AddV2Op $lhs, $rhs), (TFL_AddOp $lhs, $rhs, TFL_AF_None)>; // When batch size is known, TF BatchMatMul gets unfolded to TFL FullyConnected // with additional ops. In the case of unknown batch size, the match will // fall through to here and convert to TF Lite BatchMatMul. -def : Pat<(TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; -def : Pat<(TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; -def : Pat<(TF_SubOp $lhs, $rhs), (TFL_SubOp $lhs, $rhs, TFL_AF_None)>; -def : Pat<(TF_MulOp $lhs, $rhs), (TFL_MulOp $lhs, $rhs, TFL_AF_None)>; -def : Pat<(TF_RealDivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; -def : Pat<(TF_DivOp $lhs, $rhs), (TFL_DivOp $lhs, $rhs, TFL_AF_None)>; +def LegalizeBatchMatMulV2UnknownBatch : Pat< + (TF_BatchMatMulV2Op $lhs, $rhs, $adj_x, $adj_y), + (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; +def LegalizeBatchMatMulUnknownBatch : Pat< + (TF_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y), + (TFL_BatchMatMulOp $lhs, $rhs, $adj_x, $adj_y)>; -def : Pat<(TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, - IsDataFormatNHWC:$data_format), - (TFL_AddOp $l, $r, TFL_AF_None)>; -// TODO(jpienaar): These should be handled by the pattern rewriter, find out -// why it isn't. 
-def : Pat<(TF_Relu6Op (TF_BiasAddOp F32Tensor:$l, F32Tensor:$r, - IsDataFormatNHWC:$data_format)), - (TFL_AddOp $l, $r, TFL_AF_Relu6)>; - -def : Pat<(TF_FakeQuantWithMinMaxVarsOp $inputs, - (ConstantOp F32ElementsAttr:$min), - (ConstantOp F32ElementsAttr:$max), - $num_bits, $narrow_range), - (TFL_DequantizeOp - (TFL_QuantizeOp $inputs, - (ConvertToQuantTypeFromAttrs $inputs, $min, $max, - $num_bits, $narrow_range)))>; +def LegalizeFakeQuantWithMinMaxVars: Pat< + (TF_FakeQuantWithMinMaxVarsOp $inputs, (ConstantOp F32ElementsAttr:$min), + (ConstantOp F32ElementsAttr:$max), $num_bits, $narrow_range), + (TFL_DequantizeOp + (TFL_QuantizeOp $inputs, (ConvertToQuantTypeFromAttrs $inputs, $min, $max, + $num_bits, $narrow_range)))>; // TODO(rocky): Not all of the attributes are handled correctly. Make this // more general if there is a need. -def : Pat<(TF_QuantizeAndDequantizeV2Op $inputs, - (ConstantOp F32ElementsAttr:$min), - (ConstantOp F32ElementsAttr:$max), - $signed_input, $num_bits, $range_given, $round_mode, - $narrow_range, $axis), - (TFL_DequantizeOp - (TFL_QuantizeOp $inputs, - (ConvertToQuantTypeFromAttrs $inputs, $min, $max, - $num_bits, $narrow_range)))>; +def LegalizeQuantizeAndDequantizeV2 : Pat< + (TF_QuantizeAndDequantizeV2Op $inputs, (ConstantOp F32ElementsAttr:$min), + (ConstantOp F32ElementsAttr:$max), + $signed_input, $num_bits, $range_given, $round_mode, $narrow_range, $axis), + (TFL_DequantizeOp + (TFL_QuantizeOp $inputs, (ConvertToQuantTypeFromAttrs $inputs, $min, $max, + $num_bits, $narrow_range)))>; -def : Pat<(TF_RankOp $input), (TFL_RankOp $input)>; +def LegalizeRank : Pat<(TF_RankOp $input), (TFL_RankOp $input)>; -def : Pat<(TF_SquaredDifferenceOp $l, $r), (TFL_SquaredDifferenceOp $l, $r)>; +def LegalizeSquaredDifference : Pat<(TF_SquaredDifferenceOp $l, $r), + (TFL_SquaredDifferenceOp $l, $r)>; -// Note(ycling): We can eliminate Relu from Relu(SquaredDifference(x, y)), -// since the result of SquaredDifference is always non-negative. -// TFLite interpreter doesn't support Relu+int32 for now. So the test cases -// are failing without the following pattern to optimize Relu away fixes -// the problem. 
-def : Pat<(TF_ReluOp (TF_SquaredDifferenceOp $l, $r)), - (TFL_SquaredDifferenceOp $l, $r)>; +def LegalizeReverseV2 : Pat<(TF_ReverseV2Op $arg0, $arg1), + (TFL_ReverseV2Op $arg0, $arg1)>; -def : Pat<(TF_ReverseV2Op $arg0, $arg1), (TFL_ReverseV2Op $arg0, $arg1)>; +def LegalizeEqual : Pat<(TF_EqualOp $arg0, $arg1, + /*incompatible_shape_error=*/ConstBoolAttrTrue), + (TFL_EqualOp $arg0, $arg1)>; -def : Pat<(TF_EqualOp $arg0, $arg1, /*incompatible_shape_error=*/ConstBoolAttrTrue), (TFL_EqualOp $arg0, $arg1)>; +def LegalizePad : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>; -def : Pat<(TF_PadOp $arg0, $arg1), (TFL_PadOp $arg0, $arg1)>; +def LegalizeTile : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>; -def : Pat<(TF_TileOp $arg0, $arg1), (TFL_TileOp $arg0, $arg1)>; +def LegalizePadV2 : Pat<(TF_PadV2Op $arg0, $arg1, $cst), + (TFL_PadV2Op $arg0, $arg1, $cst)>; -def : Pat<(TF_PadV2Op $arg0, $arg1, $cst), (TFL_PadV2Op $arg0, $arg1, $cst)>; +def LegalizeMean : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), + (TFL_MeanOp $arg0, $arg1, $arg2)>; -def : Pat<(TF_MeanOp $arg0, $arg1, BoolAttr:$arg2), (TFL_MeanOp $arg0, $arg1, $arg2)>; - -def : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), (TFL_SumOp $arg, $axes, $arg2)>; +def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), + (TFL_SumOp $arg, $axes, $arg2)>; // TopK in TFL is always sorted so we ignore that attribute here. -def : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; +def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), + (TFL_TopKV2Op $input, $k)>; -def : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMinOp $arg0, $arg1, $arg2)>; +def LegalizeMin : Pat<(TF_MinOp $arg0, $arg1, BoolAttr:$arg2), + (TFL_ReduceMinOp $arg0, $arg1, $arg2)>; -def : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>; +def LegalizeMax : Pat<(TF_MaxOp $arg0, $arg1, BoolAttr:$arg2), + (TFL_ReduceMaxOp $arg0, $arg1, $arg2)>; -def : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; +def LegalizeProd : Pat<(TF_ProdOp $arg0, $arg1, BoolAttr:$arg2), + (TFL_ReduceProdOp $arg0, $arg1, $arg2)>; -def : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), - (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; +def LegalizeAny : Pat<(TF_AnyOp $input, $reduction_indices, $keep_dims), + (TFL_ReduceAnyOp $input, $reduction_indices, $keep_dims)>; -def : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; +def LegalizeCast : Pat<(TF_CastOp $arg0, BoolAttr:$arg1), (TFL_CastOp $arg0)>; -def : Pat<(TF_BatchToSpaceNDOp $input, $block_shape, $crops), (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; +def LegalizeBatchToSpaceND : Pat< + (TF_BatchToSpaceNDOp $input, $block_shape, $crops), + (TFL_BatchToSpaceNdOp $input, $block_shape, $crops)>; -def : Pat<(TF_SpaceToBatchNDOp $input, $block_shape, $paddings), (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>; +def LegalizeSpaceToBatchND : Pat< + (TF_SpaceToBatchNDOp $input, $block_shape, $paddings), + (TFL_SpaceToBatchNdOp $input, $block_shape, $paddings)>; -def : Pat<(TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format), - (TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>; +def LegalizeSpaceToDepth : Pat< + (TF_SpaceToDepthOp $input, $block_size, IsDataFormatNHWC:$data_format), + (TFL_SpaceToDepthOp $input, (convertIntAttrTo32Bit $block_size))>; -def : Pat<(TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format), - 
(TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>; +def LegalizeDepthToSpace : Pat< + (TF_DepthToSpaceOp $input, $block_size, IsDataFormatNHWC:$data_format), + (TFL_DepthToSpaceOp $input, (convertIntAttrTo32Bit $block_size))>; -def : Pat<(TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>; -def : Pat<(TF_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers), (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, $half_pixel_centers)>; +def LegalizeResizeBilinear : Pat< + (TF_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers), + (TFL_ResizeBilinearOp $images, $size, $align_corners, $half_pixel_centers)>; +def LegalizeResizeNearestNeighbor : Pat< + (TF_ResizeNearestNeighborOp $images, $size, $align_corners, + $half_pixel_centers), + (TFL_ResizeNearestNeighborOp $images, $size, $align_corners, + $half_pixel_centers)>; -def : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), (TFL_MirrorPadOp $arg0, $arg1, $cst)>; +def LegalizeMirrorPad : Pat<(TF_MirrorPadOp $arg0, $arg1, $cst), + (TFL_MirrorPadOp $arg0, $arg1, $cst)>; -def : Pat<(TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value, $validate_indices), - (TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, $default_value)>; +def LegalizeSparseToDense : Pat< + (TF_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, + $default_value, $validate_indices), + (TFL_SparseToDenseOp $sparse_indices, $output_shape, $sparse_values, + $default_value)>; -def : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>; +def LegalizeUnique : Pat<(TF_UniqueOp $arg0),(TFL_UniqueOp $arg0)>; -def : Pat<(TF_FloorModOp $arg0, $arg1), (TFL_FloorModOp $arg0, $arg1)>; -def : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>; +def LegalizeFloorMod : Pat<(TF_FloorModOp $arg0, $arg1), + (TFL_FloorModOp $arg0, $arg1)>; +def LegalizeExp : Pat<(TF_ExpOp $arg0), (TFL_ExpOp $arg0)>; -def : Pat<(TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), $bias, $alpha, $beta)>; +def LegalizeLRN : Pat< + (TF_LRNOp $arg0, $radius, F32Attr:$bias, F32Attr:$alpha, F32Attr:$beta), + (TFL_LocalResponseNormalizationOp $arg0, (convertIntAttrTo32Bit $radius), + $bias, $alpha, $beta)>; -def : Pat< - (TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $pad_to_max_output_size), - (TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold)>; +def LegalizeNonMaxSuppressionV4 : Pat< + (TF_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, + $score_threshold, $pad_to_max_output_size), + (TFL_NonMaxSuppressionV4Op $boxes, $scores, $max_output_size, $iou_threshold, + $score_threshold)>; -def : Pat< - (TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma, $pad_to_max_output_size), - (TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, $score_threshold, $soft_nms_sigma)>; +def LegalizeNonMaxSuppressionV5 : Pat< + (TF_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, + $score_threshold, $soft_nms_sigma, $pad_to_max_output_size), + (TFL_NonMaxSuppressionV5Op $boxes, $scores, $max_output_size, $iou_threshold, + $score_threshold, $soft_nms_sigma)>; -def : Pat<(TF_MatrixDiagOp $diagonal), (TFL_MatrixDiagOp $diagonal)>; +def 
LegalizeMatrixDiag : Pat<(TF_MatrixDiagOp $diagonal), + (TFL_MatrixDiagOp $diagonal)>; class I32VectorElementsAttr : ElementsAttrBase< CPred<"$_self.isa() &&" @@ -356,7 +406,7 @@ class I32VectorElementsAttr : ElementsAttrBase< "RankedTensorType::get({" # len # "}, $_builder.getIntegerType(32)), $0)"; } -def : Pat< +def LegalizeConv2DBackpropInput : Pat< (TF_Conv2DBackpropInputOp $input_sizes, $filter, $out_backprop, IsIntList1XY1:$strides, BoolAttr:$use_cudnn_on_gpu, @@ -373,9 +423,10 @@ def : Pat< /*stride_h=*/ ExtractI32At<1>:$strides, /*stride_w=*/ ExtractI32At<2>:$strides)>; -def : Pat< +def LegalizeMatrixSetDiag : Pat< (TF_MatrixSetDiagOp $input, $diagonal), (TFL_MatrixSetDiagOp $input, $diagonal)>; -def : Pat<(TF_ScatterNdOp I32Tensor:$indices, $updates, $shape), - (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>; +def LegalizeScatterNd : Pat< + (TF_ScatterNdOp I32Tensor:$indices, $updates, $shape), + (TFL_ScatterNdOp I32Tensor:$indices, $updates, $shape)>; diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 1328a2baf5d..7a16e475ce3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -158,7 +158,7 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite( random_uniform_op.seed().getSExtValue(), random_uniform_op.seed2().getSExtValue()); Distribution dist; - int num_elements = 0; + size_t num_elements = 0; if (auto output_type = random_uniform_op.output().getType().dyn_cast_or_null()) { if (auto ranked_output = output_type.dyn_cast_or_null()) { diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc index 31e3f6dd005..6202507ae91 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf_while.cc @@ -49,23 +49,19 @@ void RunOnWhile(TF::WhileOp while_op) { op->getLoc(), op->getResultTypes(), op->getOperands(), while_op.is_stateless()); // Insert call to the given function into the 'region'. - auto create_region_with_call = [&while_op](FlatSymbolRefAttr symbol, - Region& region) { + auto create_region_with_call = [&while_op](FuncOp func, Region& region) { OpBuilder builder(region); auto block = builder.createBlock(®ion); SmallVector new_operands; - auto func = while_op.getParentOfType().lookupSymbol( - symbol.getValue()); for (Type t : func.getType().getInputs()) new_operands.push_back(block->addArgument(t)); - auto call = builder.create( - while_op.getLoc(), symbol, func.getType().getResults(), new_operands); + auto call = builder.create(while_op.getLoc(), func, new_operands); builder.create(while_op.getLoc(), call.getResults()); // Mark old function as private so that it can be DCE'd if not called. 
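// The switch below from condAttr()/bodyAttr() plus a manual lookupSymbol to
// the cond_func()/body_func() accessors relies on those accessors resolving
// the referenced functions through the enclosing module's symbol table, so
// the helper can take a FuncOp directly and build the call from it.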
func.setVisibility(SymbolTable::Visibility::Private); }; - create_region_with_call(while_op.condAttr(), new_op.cond()); - create_region_with_call(while_op.bodyAttr(), new_op.body()); + create_region_with_call(while_op.cond_func(), new_op.cond()); + create_region_with_call(while_op.body_func(), new_op.body()); op->replaceAllUsesWith(new_op.getResults()); op->erase(); diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 2498a732a86..edddc7751ab 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -17,7 +17,7 @@ limitations under the License. // converting Tensorlist operations in TensorFlow dialect into operations that // can be legalized to TensorFlow Lite dialect with simple replacements. The // newly created operations are in the TensorFlow dialect if the operation can -// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op +// be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op // is used. #include @@ -332,9 +332,8 @@ struct ConvertTensorListInitOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Type dtype = op.element_dtype(); if (!(dtype.isF16() || dtype.isF32() || dtype.isF64() || - dtype.isInteger(1) || dtype.isSignlessInteger(8) || - dtype.isSignlessInteger(16) || dtype.isSignlessInteger(32) || - dtype.isSignlessInteger(64))) { + dtype.isInteger(1) || dtype.isInteger(8) || dtype.isInteger(16) || + dtype.isInteger(32) || dtype.isInteger(64))) { op.emitError( "requires element_dtype to be 1-bit/8-bit/16-bit/32-bit/64-bit " "integer or 16-bit/32-bit/64-bit float type during TF Lite " @@ -739,14 +738,18 @@ struct ConvertIdentity : public OpConversionPattern { } }; +// Returns an unranked tensor type with an element of the same type as `value` +// if `type` is a tensor of variant. Otherwise, returns `type` unmodified. +Type VariantToUnrankedTensorType(Type type, Value value) { + if (getElementTypeOrSelf(type).isa()) + return UnrankedTensorType::get(getElementTypeOrSelf(value.getType())); + return type; +} + // Changes the function type of `cond_func` and `body_func` for the given While // op. -static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { - auto module = op.getParentOfType(); - auto *context = module.getContext(); - - for (StringRef func_name : {op.cond(), op.body()}) { - FuncOp func = module.lookupSymbol(func_name); +LogicalResult UpdateFunctionTypes(TF::WhileOp op) { + for (FuncOp func : {op.cond_func(), op.body_func()}) { if (!func) continue; FunctionType func_type = func.getType(); @@ -757,42 +760,29 @@ static LogicalResult UpdateFunctionTypes(TF::WhileOp op) { // tensor type if it's a variant type. SmallVector updated_argument_types; updated_argument_types.reserve(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - Type arg_type = func_type.getInput(i); - if (getElementTypeOrSelf(arg_type).isa()) { - arg_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i).getType())); - } - updated_argument_types.push_back(arg_type); - } + for (auto it : llvm::zip(func_type.getInputs(), op.getOperands())) + updated_argument_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); - // For each result type in function's results, change it to unranked tensor - // type if it's a variant type. 
+ // Change all DT_VARIANT result types in function results to unranked tensor + // type with element type derived from the corresponding input operand. This + // is correct because while body's inputs and results have the same type. SmallVector updated_result_types; updated_result_types.reserve(num_results); - for (int i = 0; i < num_results; ++i) { - Type result_type = func_type.getResult(i); - if (getElementTypeOrSelf(result_type).isa()) { - // Here update the variant type with the unranked tensor type derived - // from the corresponding input operand. This is correct because while - // body's inputs and results have the same type. - result_type = UnrankedTensorType::get( - getElementTypeOrSelf(op.getOperand(i).getType())); - } - updated_result_types.push_back(result_type); - } + for (auto it : llvm::zip(func_type.getResults(), op.getOperands())) + updated_result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); // Change `func`'s argument type to `unranked_argument_types`. If it // return types contain a `DT_VARIANT`, change it to the unranked type // derived from the corresponding argument. func.setType(FunctionType::get(updated_argument_types, updated_result_types, - context)); + op.getContext())); // Change the argument type for the first block. - Block &body_first_bb = func.front(); - for (int i = 0; i < body_first_bb.getNumArguments(); ++i) { - body_first_bb.getArgument(i).setType(updated_argument_types[i]); - } + llvm::for_each(func.getArguments(), [&](BlockArgument &arg) { + arg.setType(updated_argument_types[arg.getArgNumber()]); + }); } return success(); } @@ -805,25 +795,60 @@ struct ConvertWhile : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { llvm::SmallVector result_types; result_types.reserve(op.getNumOperands()); - for (int i = 0, e = operands.size(); i != e; ++i) { - Type result_ty = op.getResult(i).getType(); + // Change all DT_VARIANT result types to unranked tensor type. + for (auto it : llvm::zip(op.getResultTypes(), operands)) + result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); - // If we notice the result type is a DT_VARIANT, we change the - // corresponding result type to unranked tensor type. - if (getElementTypeOrSelf(result_ty).isa()) { - Type element_ty = getElementTypeOrSelf(operands[i].getType()); - result_ty = UnrankedTensorType::get(element_ty); + // Create a new while op with new operands and updated result types. + auto converted = rewriter.create(op.getLoc(), result_types, + operands, op.getAttrs()); + converted.removeAttr("T"); + UpdateFunctionTypes(converted); + + rewriter.replaceOp(op, converted.getResults()); + return success(); + } +}; + +struct ConvertWhileRegion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::WhileRegionOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector result_types; + result_types.reserve(op.getNumOperands()); + // Change all DT_VARIANT result types to unranked tensor type. + for (auto it : llvm::zip(op.getResultTypes(), operands)) + result_types.push_back( + VariantToUnrankedTensorType(std::get<0>(it), std::get<1>(it))); + + // Create a new while op with new operands and updated result types. 
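+    // Zipping result types with operands above is safe: the i-th result of a
+    // while op carries the same loop-carried value as its i-th operand, so a
+    // variant result's element type can be taken from the matching operand.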
+ auto converted = rewriter.create<TF::WhileRegionOp>( + op.getLoc(), result_types, operands, op.getAttrs()); + + // Inline the regions from the old while into the new one, and apply + // signature conversion to the inlined regions. + for (auto it : llvm::zip(op.getRegions(), converted.getRegions())) { + Region &old_region = *std::get<0>(it); + Region &new_region = *std::get<1>(it); + + Block &entry = old_region.front(); + // Build signature conversion for the region. + TypeConverter::SignatureConversion signature_conversion(operands.size()); + for (auto it : llvm::zip(entry.getArguments(), operands)) { + BlockArgument arg = std::get<0>(it); + signature_conversion.addInputs( + arg.getArgNumber(), + VariantToUnrankedTensorType(arg.getType(), std::get<1>(it))); } - result_types.push_back(result_ty); + + rewriter.inlineRegionBefore(old_region, new_region, new_region.end()); + rewriter.applySignatureConversion(&new_region, signature_conversion); } - // Clone original while op with new operands and updated result types. - auto cloned = rewriter.create<TF::WhileOp>(op.getLoc(), result_types, - operands, op.getAttrs()); - cloned.removeAttr("T"); - UpdateFunctionTypes(cloned); - - rewriter.replaceOp(op, cloned.getResults()); + rewriter.replaceOp(op, converted.getResults()); return success(); } }; @@ -872,7 +897,8 @@ LogicalResult LowerStaticTensorListPass::RewriteFunction( ConvertTensorListGetItem, ConvertTensorListLength, ConvertTensorListPushBack, ConvertTensorListReserve, ConvertTensorListSetItem, ConvertTensorListStack, - ConvertTensorListResize, ConvertWhile>(context); + ConvertTensorListResize, ConvertWhile, ConvertWhileRegion>( + context); return applyPartialConversion(func, target, patterns); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index d26a4906420..eeecfac67cf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -160,6 +160,31 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, return false; } +// Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the values +// of `indices` are from 0 to n-1, the output tensor is identical to +// `params`. +bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, + DenseIntElementsAttr indices) { + auto params_type = params.getType().dyn_cast<RankedTensorType>(); + auto indices_type = indices.getType().dyn_cast<RankedTensorType>(); + // Checks that the shape of `params` is [n, ...] and the shape of `indices` is + // [n, 1], i.e. each 2-D index selects one row of `params` along its first + // dimension. As long as `indices` walks that dimension in order, the output + // is identical to the input. + if (!params_type || !indices_type || indices_type.getRank() != 2 || + indices_type.getDimSize(0) != params_type.getDimSize(0) || + indices_type.getDimSize(1) != 1) + return false; + + // Checks that the values in `indices` run from 0 to n-1. + int cur_value = 0; + for (const auto &v : indices.getValues<APInt>()) { + if (v.getSExtValue() != cur_value) return false; + ++cur_value; + } + + return true; +} + // Expand Attribute 'a' to 4D with all 1s except 1 dimension. // Which dimension depends on whether 'is_depthwise' is true or false.
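// For CanOptimizeIdentityGatherNdOrScatterNdOp above, the value check boils
// down to verifying that `indices` is the column vector [0, 1, ..., n-1]. A
// self-contained sketch of the same test, assuming the indices have already
// been unpacked into a plain vector:
//
//   bool IsIotaIndices(const std::vector<int64_t>& indices) {
//     for (size_t i = 0; i < indices.size(); ++i) {
//       if (indices[i] != static_cast<int64_t>(i)) return false;
//     }
//     return true;
//   }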
ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) { @@ -197,9 +222,10 @@ TypeAttr RescaleQtype(Type input, Attribute factor) { DenseElementsAttr GetShape(Value output_val) { auto output_type = output_val.getType().cast(); auto shape_vector = output_type.getShape(); - std::vector shape(shape_vector.size()); - for (int i = 0; i < shape_vector.size(); ++i) { - shape[i] = shape_vector[i]; + std::vector shape; + shape.reserve(shape_vector.size()); + for (auto shape_object : shape_vector) { + shape.push_back(shape_object); } return mlir::DenseElementsAttr::get( RankedTensorType::get( @@ -684,7 +710,7 @@ struct ConvertTrivialTransposeOpToReshapeOp SmallVector old_major_index_ordering; SmallVector new_major_index_ordering; - for (int i = 0; i < input_shape.size(); i++) { + for (int i = 0, end = input_shape.size(); i < end; i++) { if (input_shape[i] != 1) { old_major_index_ordering.push_back(i); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 18c1912d4c7..2311ae0668c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -83,16 +83,15 @@ class FoldIfOp : public OpRewritePattern { if (!llvm::hasSingleElement(parent_op)) return failure(); // Find the then and else branch functions. - SymbolTable table(op.getParentOfType()); - FuncOp then_branch = table.lookup(op.then_branch()); - FuncOp else_branch = table.lookup(op.else_branch()); + FuncOp then_func = op.then_func(); + FuncOp else_func = op.else_func(); // If the If has no uses and its functions are side-effect free, then // remove. // TODO(jpienaar): Remove once recusive side-effects are supported. if (op.use_empty() && (op.is_stateless() || - (IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) { + (IsSideEffectFree(then_func) && IsSideEffectFree(else_func)))) { rewriter.eraseOp(op.getOperation()); return success(); } @@ -109,7 +108,7 @@ class FoldIfOp : public OpRewritePattern { // Identify the branch to inline. bool cond_value = (*cond.int_value_begin()).getSExtValue(); - FuncOp func = cond_value ? then_branch : else_branch; + FuncOp func = cond_value ? then_func : else_func; // Make sure that the function has exactly one block to simplify inlining. // TFLite doesn't use control flow with blocks so functions with more than diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 1fae567c835..3c5fc7a0c5e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -37,22 +37,19 @@ class HasRankAtMost : Constraint< // Multi-pattern consisting of matching stand-alone convolution op followed by // activation op. 
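// As a rough example, the rewrite below turns
//   %conv = "tfl.conv_2d"(%input, %filter, %bias)
//             {fused_activation_function = "NONE", ...}
//   %out  = "tfl.relu"(%conv)
// into a single
//   %out = "tfl.conv_2d"(%input, %filter, %bias)
//            {fused_activation_function = "RELU", ...}
// and only fires when %conv has no other users (the HasOneUse constraint).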
multiclass FuseActFnIntoConvOpPat { - def : Pat<(ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w)), - (TFL_Conv2DOp $input, $filter, $bias, - $h_factor, $w_factor, ActFnAttr, - $padding, $stride_h, $stride_w), - [(HasOneUse $conv_out)]>; - def : Pat<(ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w, - $multiplier)), - (TFL_DepthwiseConv2DOp $input, $filter, $bias, - $h_factor, $w_factor, ActFnAttr, - $padding, $stride_h, $stride_w, - $multiplier), - [(HasOneUse $conv_out)]>; + def FuseActivationFuncWithConv#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_Conv2DOp:$conv_out $input, $filter, $bias, $h_factor, + $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w)), + (TFL_Conv2DOp $input, $filter, $bias, $h_factor, $w_factor, ActFnAttr, + $padding, $stride_h, $stride_w), + [(HasOneUse $conv_out)]>; + def FuseActivationFuncWithDepthwiseConv#ActFnOp#ActFnAttr : Pat< + (ActFnOp (TFL_DepthwiseConv2DOp:$conv_out $input, $filter, $bias, $h_factor, + $w_factor, TFL_AF_None, $padding, $stride_h, $stride_w, + $multiplier)), + (TFL_DepthwiseConv2DOp $input, $filter, $bias, $h_factor, $w_factor, + ActFnAttr, $padding, $stride_h, $stride_w, $multiplier), + [(HasOneUse $conv_out)]>; } // TODO(hinsu): Also fuse ops corresponding to SIGN_BIT fused @@ -73,33 +70,29 @@ class CanFuseConvOrDepthwiseConv : Constraint< // constant folding the bias and the binary op's constant operand. The following // pattern restricts to float constant values for now. multiclass FuseBinaryOpToPrecedingAffine { - def : Pat<(binaryOp (TFL_Conv2DOp:$output $input, $filter, - (ConstantOp F32ElementsAttr:$bias), - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), $act_fn), - (TFL_Conv2DOp $input, $filter, - (binaryOp (ConstantOp $bias), - (ConstantOp $value), TFL_AF_None), - $h_factor, $w_factor, $act_fn, - $padding, $stride_h, $stride_w), - [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), - (HasOneUse $output)]>; - def : Pat<(binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter, - (ConstantOp F32ElementsAttr:$bias), - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w, - $multiplier), - (ConstantOp F32ElementsAttr:$value), $act_fn), - (TFL_DepthwiseConv2DOp $input, $filter, - (binaryOp (ConstantOp $bias), - (ConstantOp $value), - TFL_AF_None), - $h_factor, $w_factor, $act_fn, - $padding, $stride_h, $stride_w, - $multiplier), - [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), - (HasOneUse $output)]>; + def FuseBinaryOpWithConv#binaryOp : Pat< + (binaryOp (TFL_Conv2DOp:$output $input, $filter, + (ConstantOp F32ElementsAttr:$bias), $h_factor, $w_factor, + TFL_AF_None, $padding, $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), $act_fn), + (TFL_Conv2DOp $input, $filter, + (binaryOp (ConstantOp $bias), + (ConstantOp $value), TFL_AF_None), + $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), + (HasOneUse $output)]>; + def FuseBinaryOpWithDepthwiseConv#binaryOp : Pat< + (binaryOp (TFL_DepthwiseConv2DOp:$output $input, $filter, + (ConstantOp F32ElementsAttr:$bias), + $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, + $stride_w, $multiplier), + (ConstantOp F32ElementsAttr:$value), $act_fn), + (TFL_DepthwiseConv2DOp $input, $filter, + (binaryOp (ConstantOp $bias), (ConstantOp $value), TFL_AF_None), + 
$h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w, + $multiplier), + [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), + (HasOneUse $output)]>; } foreach binaryOp = [TFL_AddOp, TFL_SubOp] in defm : FuseBinaryOpToPrecedingAffine; @@ -116,43 +109,43 @@ def ExpandTo4DForDepthwiseConv: NativeCodeCall< // The following pattern restricts to float constant values for now. multiclass FuseMulOrDivWithConv2dOrDepthwiseConv2d { - def : Pat<(BinaryOp (TFL_DepthwiseConv2DOp:$output $input, - (ConstantOp F32ElementsAttr:$filter), - (ConstantOp F32ElementsAttr:$bias), - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w, - $multiplier), - (ConstantOp F32ElementsAttr:$value), $act_fn), - (TFL_DepthwiseConv2DOp $input, - (BinaryOp (ConstantOp $filter), - (ConstantOp - (ExpandTo4DForDepthwiseConv $value)), - TFL_AF_None), - (BinaryOp (ConstantOp $bias), - (ConstantOp $value), - TFL_AF_None), - $h_factor, $w_factor, $act_fn, - $padding, $stride_h, $stride_w, - $multiplier), - [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), - (HasOneUse $output)]>; - def : Pat<(BinaryOp (TFL_Conv2DOp:$conv_output $input, - (ConstantOp F32ElementsAttr:$filter), - (ConstantOp F32ElementsAttr:$bias), - $h_factor, $w_factor, TFL_AF_None, - $padding, $stride_h, $stride_w), - (ConstantOp F32ElementsAttr:$value), $act_fn), - (TFL_Conv2DOp $input, - (BinaryOp (ConstantOp $filter), - (ConstantOp (ExpandTo4DForConv $value)), - TFL_AF_None), - (BinaryOp (ConstantOp $bias), - (ConstantOp $value), - TFL_AF_None), - $h_factor, $w_factor, $act_fn, - $padding, $stride_h, $stride_w), - [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), - (HasOneUse $conv_output)]>; + def FuseMulOrDivWithDepthwiseConv#BinaryOp : Pat< + (BinaryOp (TFL_DepthwiseConv2DOp:$output $input, + (ConstantOp F32ElementsAttr:$filter), + (ConstantOp F32ElementsAttr:$bias), + $h_factor, $w_factor, TFL_AF_None, $padding, $stride_h, + $stride_w, $multiplier), + (ConstantOp F32ElementsAttr:$value), $act_fn), + (TFL_DepthwiseConv2DOp $input, + (BinaryOp + (ConstantOp $filter), + (ConstantOp (ExpandTo4DForDepthwiseConv $value)), + TFL_AF_None), + (BinaryOp + (ConstantOp $bias), + (ConstantOp $value), + TFL_AF_None), + $h_factor, $w_factor, $act_fn, $padding, $stride_h, + $stride_w, $multiplier), + [(CanFuseConvOrDepthwiseConv<"true"> $filter, $value), + (HasOneUse $output)]>; + def FuseMulOrDivWithConv#BinaryOp : Pat< + (BinaryOp (TFL_Conv2DOp:$conv_output $input, + (ConstantOp F32ElementsAttr:$filter), + (ConstantOp F32ElementsAttr:$bias), + $h_factor, $w_factor, TFL_AF_None, + $padding, $stride_h, $stride_w), + (ConstantOp F32ElementsAttr:$value), $act_fn), + (TFL_Conv2DOp $input, + (BinaryOp (ConstantOp $filter), + (ConstantOp (ExpandTo4DForConv $value)), + TFL_AF_None), + (BinaryOp (ConstantOp $bias), + (ConstantOp $value), + TFL_AF_None), + $h_factor, $w_factor, $act_fn, $padding, $stride_h, $stride_w), + [(CanFuseConvOrDepthwiseConv<"false"> $filter, $value), + (HasOneUse $conv_output)]>; } foreach BinaryOp = [TFL_DivOp, TFL_MulOp] in @@ -177,7 +170,7 @@ class OperandHasRank : Constraint< CPred<"$0.getType().cast().getRank() == " # n>>; // Matching HardSwish -def : Pat< +def MatchHardSwishPattern1 : Pat< (TFL_MulOp (TFL_MulOp $x, (TFL_AddOp @@ -190,7 +183,7 @@ def : Pat< (TFL_HardSwishOp $x), [(EqualOperands $x, $y)]>; -def : Pat< +def MatchHardSwishPattern2 : Pat< (TFL_MulOp $x, (TFL_MulOp @@ -207,7 +200,7 @@ def : Pat< // Matching HardSwish with extra FakeQuant. 
These FakeQuant ops were due to // incorrect placement in the quantization aware training. // TODO(b/149735743): We should make the placement automatically. -def : Pat< +def MatchHardSwishQuantized : Pat< (TFL_MulOp (TFL_DequantizeOp (TFL_QuantizeOp (TFL_MulOp $x, (TFL_DequantizeOp (TFL_QuantizeOp (TFL_AddOp @@ -238,7 +231,8 @@ multiclass L2NormalizePatterns { // This pattern constructs L2NormalizationOp from // Mul->Rsqrt->Sum->Square Or // Div->sqrt->Sum->Square - def : Pat<(FirstOp $operand1, + def L2NormalizePattern1#FirstOp#SecondOp : Pat< + (FirstOp $operand1, (SecondOp (TFL_SumOp (TFL_SquareOp:$sq_op $square_operand), @@ -251,7 +245,8 @@ multiclass L2NormalizePatterns { // Below patterns for L2Normalize when there is an Add or Maximum // adding or clamping to a small constant scalar. - def : Pat<(FirstOp $operand1, + def L2NormalizePattern2#FirstOp#SecondOp : Pat< + (FirstOp $operand1, (SecondOp (TFL_AddOp (TFL_SumOp @@ -265,7 +260,8 @@ multiclass L2NormalizePatterns { (L2NormValidReduceIndex $sq_op, $axis), (ConstDoubleValueLessThan<"1e-3"> $epsilon)]>; - def : Pat<(FirstOp $operand1, + def L2NormalizePattern3#FirstOp#SecondOp : Pat< + (FirstOp $operand1, (SecondOp (TFL_MaximumOp (TFL_SumOp @@ -302,14 +298,16 @@ def HaveSameType : Constraint>; // Pattern for skipping Tile if it is mainly for broadcasting and the // Op is already supporting broadcasting. multiclass FuseTileBroadcastIntoFollowingBinary { - def : Pat<(BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)), - $operand, $act_func), - (BinaryOp $input, $operand, $act_func), + def FuseTileBroadcastToBinaryOp1#BinaryOp : Pat< + (BinaryOp:$result (TFL_TileOp $input, (ConstantOp $tile)), + $operand, $act_func), + (BinaryOp $input, $operand, $act_func), [(OperandsBroadcastToOutputType $input, $operand, $result)]>; - def : Pat<(BinaryOp:$result $operand, - (TFL_TileOp $input, (ConstantOp $tile)), $act_func), - (BinaryOp $operand, $input, $act_func), + def FuseTileBroadcastToBinaryOp2#BinaryOp : Pat< + (BinaryOp:$result $operand, + (TFL_TileOp $input, (ConstantOp $tile)), $act_func), + (BinaryOp $operand, $input, $act_func), [(OperandsBroadcastToOutputType $operand, $input, $result)]>; } @@ -318,9 +316,10 @@ multiclass FusedBinaryActivationFuncOpPat { foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], [TFL_Relu6Op, TFL_AF_Relu6], [TFL_Relu1Op, TFL_AF_Relu1]] in { - def : Pat<(actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)), - (BinaryOp $lhs, $rhs, actFnPair[1]), - [(HasOneUse $binary_out)]>; + def FuseBinaryWithActivation#BinaryOp#actFnPair[0] : Pat< + (actFnPair[0] (BinaryOp:$binary_out $lhs, $rhs, TFL_AF_None)), + (BinaryOp $lhs, $rhs, actFnPair[1]), + [(HasOneUse $binary_out)]>; } } @@ -340,21 +339,22 @@ foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_DivOp, TFL_MulOp] in { // transformation, the shape of the binary op result is [40x1600], which // couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to // make sure $rhs is the tail shape of $lhs. - def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), - (ConstantOp:$rhs $a), TFL_AF_None), - (TFL_ReshapeOp (BinaryOp $input, $rhs, TFL_AF_None), $shape), - // The broadcasting of "BinaryOp" only happens in the lower - // dimensions, and the higher dimensions are same, so we know the - // result and input of the "BinaryOp" in the source pattern have - // the same shape, which is defined by `shape`. - [(IsTailOfShape $rhs, $lhs), - (HasOneUse $lhs), - // The result of the new "BinaryOp" will have the same shape as - // `input`. 
In other words, the shape of the `Reshape` op are not - // changed after the transformation. - (IsTailOfShape $rhs, $input), - (HasRankAtMost<5> $input), - (HasRankAtMost<5> $rhs)]>; + def MoveBinaryOpBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs $a), $act_fn), + (TFL_ReshapeOp (BinaryOp $input, $rhs, $act_fn), $shape), + // The broadcasting of "BinaryOp" only happens in the lower + // dimensions, and the higher dimensions are same, so we know the + // result and input of the "BinaryOp" in the source pattern have + // the same shape, which is defined by `shape`. + [(IsTailOfShape $rhs, $lhs), + (HasOneUse $lhs), + // The result of the new "BinaryOp" will have the same shape as + // `input`. In other words, the shape of the `Reshape` op are not + // changed after the transformation. + (IsTailOfShape $rhs, $input), + (HasRankAtMost<5> $input), + (HasRankAtMost<5> $rhs)]>; } foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, @@ -370,19 +370,20 @@ foreach BinaryOp = [TFL_FloorDivOp, TFL_FloorModOp, TFL_MinimumOp, // transformation, the shape of the binary op result is [40x1600], which // couldn't be reshaped to [1,40,40]. `IsTailOfShape` constraint is added to // make sure $rhs is the tail shape of $lhs. - def : Pat<(BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), - (ConstantOp:$rhs $a)), - (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape), - // The broadcasting of "BinaryOp" only happens in the lower - // dimensions, and the higher dimensions are same, so we know the - // result and input of the "BinaryOp" in the source pattern have - // the same shape, which is defined by `shape`. - [(IsTailOfShape $rhs, $lhs), - (HasOneUse $lhs), - // The result of the new "BinaryOp" will have the same shape as - // `input`. In other words, the shape of the `Reshape` op are not - // changed after the transformation. - (IsTailOfShape $rhs, $input)]>; + def MoveBinaryOpBeforeReshape#BinaryOp : Pat< + (BinaryOp (TFL_ReshapeOp:$lhs $input, (ConstantOp:$shape $s)), + (ConstantOp:$rhs $a)), + (TFL_ReshapeOp (BinaryOp $input, $rhs), $shape), + // The broadcasting of "BinaryOp" only happens in the lower + // dimensions, and the higher dimensions are same, so we know the + // result and input of the "BinaryOp" in the source pattern have + // the same shape, which is defined by `shape`. + [(IsTailOfShape $rhs, $lhs), + (HasOneUse $lhs), + // The result of the new "BinaryOp" will have the same shape as + // `input`. In other words, the shape of the `Reshape` op are not + // changed after the transformation. + (IsTailOfShape $rhs, $input)]>; } // Reorder the element-wise value operations and the element move operations, @@ -392,9 +393,10 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, TFL_TanhOp, TFL_SqrtOp, TFL_SquareOp] in { foreach MoveOp = [TFL_DepthToSpaceOp, TFL_ExpandDimsOp, TFL_SqueezeOp, TFL_ReshapeOp, TFL_TransposeOp] in { - def : Pat<(ValueOp:$value (MoveOp:$move $input, $move_def)), - (MoveOp (ValueOp $input), $move_def), - [(HasOneUse $move)]>; + def ReorderElementwiseAndMoveOperations#ValueOp#MoveOp : Pat< + (ValueOp:$value (MoveOp:$move $input, $move_def)), + (MoveOp (ValueOp $input), $move_def), + [(HasOneUse $move)]>; } } @@ -402,17 +404,20 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, // if called without a ranked tensor it will fail. def GetShape: NativeCodeCall<"GetShape($0)">; -// Convert squeeze to reshape if possible. 
-def : Pat<(TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), - (TFL_ReshapeOp $input, - (ConstantOp (GetShape $squeeze_op))), - [(AnyStaticShapeTensor $squeeze_op)]>; +// Returns True if the operand type is RankedTensorType. +def HasRankedTensor : Constraint< + CPred<"$0.getType().isa()">>; + +def ConvertSqueezeToReshape : Pat< + (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), + (TFL_ReshapeOp $input, (ConstantOp (GetShape $squeeze_op))), + [(HasRankedTensor $squeeze_op)]>; // Convert expand_dims to reshape if possible. -def : Pat<(TFL_ExpandDimsOp:$expand_dims_op $input, $dim), - (TFL_ReshapeOp $input, - (ConstantOp (GetShape $expand_dims_op))), - [(AnyStaticShapeTensor $expand_dims_op)]>; +def ConvertExpandDimsToReshape : Pat< + (TFL_ExpandDimsOp:$expand_dims_op $input, $dim), + (TFL_ReshapeOp $input, (ConstantOp (GetShape $expand_dims_op))), + [(AnyStaticShapeTensor $expand_dims_op)]>; class FloatValueEquals : Constraint().getNumElements() == 1 &&" @@ -420,25 +425,32 @@ class FloatValueEquals : Constraint().getValues().begin() == " # val>>; // ReLU patterns -def : Pat<(TFL_MinimumOp (TFL_MaximumOp $input, - (ConstantOp $NegOne)), - (ConstantOp $One)), - (TFL_Relu1Op $input), - [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; +def MatchReluPattern : Pat< + (TFL_MaximumOp $input, (ConstantOp $Zero)), + (TFL_ReluOp $input), + [(FloatValueEquals<"0"> $Zero)]>; -def : Pat<(TFL_MaximumOp (TFL_MinimumOp $input, - (ConstantOp $One)), - (ConstantOp $NegOne)), - (TFL_Relu1Op $input), - [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; +def MatchRelu1Pattern1 : Pat< + (TFL_MinimumOp (TFL_MaximumOp $input, (ConstantOp $NegOne)), + (ConstantOp $One)), + (TFL_Relu1Op $input), + [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; -def : Pat<(TFL_MaximumOp (TFL_MulOp:$mul_out $input1, - (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None), - $input2), - (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha), - [(ConstDoubleValueLessThan<"1"> $alpha), - (EqualOperands $input1, $input2), - (HasOneUse $mul_out)]>; +def MatchRelu1Pattern2 : Pat< + (TFL_MaximumOp (TFL_MinimumOp $input, (ConstantOp $One)), + (ConstantOp $NegOne)), + (TFL_Relu1Op $input), + [(FloatValueEquals<"-1"> $NegOne), (FloatValueEquals<"1"> $One)]>; + +def MatchLeakyRelu : Pat< + (TFL_MaximumOp + (TFL_MulOp:$mul_out $input1, + (ConstantOp F32ElementsAttr:$alpha), TFL_AF_None), + $input2), + (TFL_LeakyReluOp $input1, ExtractSingleElementAsFloat:$alpha), + [(ConstDoubleValueLessThan<"1"> $alpha), + (EqualOperands $input1, $input2), + (HasOneUse $mul_out)]>; def RemoveTrivialCast : Pat<(TFL_CastOp:$output $input), (replaceWithValue $input), @@ -451,23 +463,25 @@ def PReluAlphaRankCheck : Constraint< // PReLU pattern from Keras: // f(x) = Relu(x) + (-alpha * Relu(-x)) -def : Pat<(TFL_AddOp - (TFL_ReluOp:$relu_out $input1), - (TFL_MulOp:$mul_out - (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)), - $neg_alpha, - TFL_AF_None), - TFL_AF_None), - (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)), - [(EqualOperands $input1, $input2), - (PReluAlphaRankCheck $neg_alpha, $input1), - (HasOneUse $relu_out), - (HasOneUse $mul_out), - (HasOneUse $input_neg_out)]>; +def MatchPRelu : Pat< + (TFL_AddOp + (TFL_ReluOp:$relu_out $input1), + (TFL_MulOp:$mul_out + (TFL_ReluOp (TFL_NegOp:$input_neg_out $input2)), + $neg_alpha, + TFL_AF_None), + TFL_AF_None), + (TFL_PReluOp $input1, (TFL_NegOp $neg_alpha)), + [(EqualOperands $input1, $input2), + (PReluAlphaRankCheck $neg_alpha, $input1), + (HasOneUse 
$relu_out), + (HasOneUse $mul_out), + (HasOneUse $input_neg_out)]>; // The constant folding in this pass might produce constant in the tf dialect. // This rule is to legalize these constant to the tfl dialect. -def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; +def LegalizeConstOp : Pat< + (TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; // Reorders adds to allow constant folding. // Add --> Add $input, $constantA @@ -476,13 +490,49 @@ def : Pat<(TF_ConstOp ElementsAttr:$value), (TFL_ConstOp $value)>; // Add --> $input // \--> Add ($constantA, $constantB) foreach ActFun = [TFL_AF_Relu, TFL_AF_Relu6, TFL_AF_Relu1, TFL_AF_None] in { - def : Pat<(TFL_AddOp - (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None), - (ConstantOp $b), ActFun), - (TFL_AddOp $input, - (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None), - ActFun), - [(HasOneUse $first_output)]>; + def ReorderAddToAllowConstFold_ActFunc_#ActFun : Pat< + (TFL_AddOp + (TFL_AddOp:$first_output $input, (ConstantOp $a), TFL_AF_None), + (ConstantOp $b), ActFun), + (TFL_AddOp $input, + (TFL_AddOp (ConstantOp $a), (ConstantOp $b), TFL_AF_None), + ActFun), + [(HasOneUse $first_output)]>; } +// We can eliminate Relu from Relu(SquaredDifference(x, y)), +// since the result of SquaredDifference is always non-negative. +// TFLite interpreter doesn't support Relu+int32 for now. So the test cases +// are failing without the following pattern to optimize Relu away fixes +// the problem. +def OptimizeReluSquaredDifference : Pat< + (TFL_ReluOp (TFL_SquaredDifferenceOp $l, $r)), + (TFL_SquaredDifferenceOp $l, $r)>; + +// Optimize X^1 o X +def OptimizePow1ToIdentity : Pat< + (TFL_PowOp $input, + (ConstantOp ConstantAttr, "1.0f">)), + (replaceWithValue $input)>; + +// Optimize X^2 to X*X +def OptimizePow2ToSquare : Pat< + (TFL_PowOp $input, + (ConstantOp ConstantAttr, "2.0f">)), + (TFL_MulOp $input, $input, TFL_AF_None)>; + +def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint())">>; + +def OptimizeIdentityGatherNdOp : Pat< + (TFL_GatherNdOp $params, (ConstantOp I32ElementsAttr: $indices)), + (replaceWithValue $params), + [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>; + +def OptimizeIdentityScatterNdOp : Pat< + (TFL_ScatterNdOp (ConstantOp I32ElementsAttr: $indices), $params, $ignored), + (replaceWithValue $params), + [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>; + diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index af97931b2a3..804a391231a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -91,6 +91,9 @@ std::unique_ptr> CreateWhileOutlinePass(); // Verifies runtime constraints. std::unique_ptr> CreateRuntimeVerifyPass(); +// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp +std::unique_ptr> CreateRaiseCustomOpsPass(); + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 3d2ab662e6f..3be6246c0dd 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -42,6 +43,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h" #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // The cmd line flag to turn on/off Tf.Text API fusion. @@ -56,9 +58,11 @@ namespace TFL { namespace { constexpr char kTFAPIImplements[] = "tf.api_implements"; -constexpr char kTfTextAPIPRefix[] = "tftext:"; +constexpr char kTFTextAPIPrefix[] = "tftext:"; constexpr char kTfNMSPadded[] = "non_max_suppression_padded_v2"; +using mlir::TF::FuncAttr; + // Abstracts the conversion of the embedded lookup composite function. class ConvertEmbeddedLookupFunc { public: @@ -161,7 +165,9 @@ class PrepareCompositeFunctionsPass explicit PrepareCompositeFunctionsPass() {} private: + // TODO(b/160915525): Consolidate FuncAttr and StringAttr into one. void ConvertTFImplements(FuncOp func, StringAttr attr); + void ConvertTFImplementsWithAttributes(FuncOp func, FuncAttr attr); void ConvertTFAPIImplements(FuncOp func, StringAttr attr, ModuleOp module); void runOnOperation() override; }; @@ -204,10 +210,23 @@ void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func, } } +void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( + FuncOp func, FuncAttr attr) { + auto api_name = attr.GetName().getLeafReference(); + bool enable_fuse_tftext = + fuse_tftext_flag || IsTFTextRegistered(tensorflow::OpRegistry::Global()); + if (api_name.startswith(kTFTextAPIPrefix) && enable_fuse_tftext) { + if (failed(ConvertTFTextAPI(func, api_name, attr))) { + return signalPassFailure(); + } + } +} + LogicalResult CheckOutputConsumer( Operation* call_op, int expected_num_outputs, llvm::DenseSet expected_consumer_indices) { - if (call_op->getNumResults() != expected_num_outputs) return failure(); + const int num_results = call_op->getNumResults(); + if (num_results != expected_num_outputs) return failure(); for (int i = 0; i < expected_num_outputs; ++i) { auto it = expected_consumer_indices.find(i); @@ -220,21 +239,31 @@ LogicalResult CheckOutputConsumer( } LogicalResult CheckFusableKerasLstm(FuncOp lstm_func, ModuleOp module) { - bool check_failed = false; for (auto func : module.getOps()) { - func.walk([&](Operation* op) { - auto call_op = dyn_cast_or_null(op); - if (call_op && op->getAttrOfType("f").getRootReference() == - lstm_func.getName()) { + if (func == lstm_func) continue; + auto result = func.walk([&](CallOpInterface op) { + if (dyn_cast(op.resolveCallable()) == lstm_func) { // Keras LSTM have 5 outputs. - // We should make sure only the first or the second output are consumed. - if (failed(CheckOutputConsumer(call_op, 5, {0, 1}))) - check_failed = true; + // We should make sure only the first or the second output are + // consumed. 
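+          // For example, a caller that only uses results 0 and/or 1 of the
+          // lstm_func call keeps the function fusable; any use of results 2,
+          // 3 or 4 makes CheckOutputConsumer fail and the fusion is skipped.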
+ if (failed(CheckOutputConsumer(op.getOperation(), 5, {0, 1}))) + return WalkResult::interrupt(); } + return WalkResult::advance(); }); + + if (result.wasInterrupted()) return failure(); + } + + // We should know the batch size in advance for the lstm fusion. + // A good indicator of batch size is both cell state and input state have + // fixed shape. (indices 1 & 2). + for (int i = 1; i < 3; ++i) { + auto input = lstm_func.getArgument(i); + auto input_type = input.getType().dyn_cast_or_null(); + if (!input_type || !input_type.hasStaticShape()) return failure(); } - if (check_failed) return failure(); return success(); } @@ -256,26 +285,27 @@ void PrepareCompositeFunctionsPass::ConvertTFAPIImplements(FuncOp func, OpBuilder builder(func.getBody()); if (failed(ConvertKerasLSTMLayer(func, &builder))) return signalPassFailure(); - } else if (fuse_tftext_flag || - IsTfTextRegistered(tensorflow::OpRegistry::Global())) { - if (attr.getValue().startswith(kTfTextAPIPRefix)) { - if (failed(ConvertTFTextAPI(func, attr.getValue()))) { - return signalPassFailure(); - } - } } } void PrepareCompositeFunctionsPass::runOnOperation() { auto module = getOperation(); for (auto func : module.getOps()) { - // We have two kinds of implements: - // 1) tf._implements. - // 2) tf.api_implements. + // We have three kinds of implements: + // 1) tf._implements, with string attributes. + // 2) tf._implements, with proto attributes. + // 3) tf.api_implements. // We need to handle them separately. - auto tf_implements_attr = func.getAttrOfType(kTFImplements); + auto tf_implements_attr_str = func.getAttrOfType(kTFImplements); + if (tf_implements_attr_str) { + ConvertTFImplements(func, tf_implements_attr_str); + continue; + } + + auto tf_implements_attr = func.getAttrOfType(kTFImplements); if (tf_implements_attr) { - ConvertTFImplements(func, tf_implements_attr); + ConvertTFImplementsWithAttributes(func, tf_implements_attr); + continue; } auto tf_api_implements_attr = diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 6ee988496fa..62688937d7e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -41,7 +41,9 @@ limitations under the License. #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -49,6 +51,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" @@ -57,7 +60,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -494,7 +499,8 @@ struct ConvertTFStridedSlice : public RewritePattern { original_input_type.getShape(); SmallVector new_shape; int index = 0; - while (index < original_input_shape.size() || new_axis_mask) { + const int original_input_rank = original_input_shape.size(); + while (index < original_input_rank || new_axis_mask) { if (new_axis_mask & 1) { new_shape.emplace_back(1); } else { @@ -696,6 +702,23 @@ LogicalResult ValidateOp(Operation *op) { return failure(has_illegal_ops); } +// Converts a set of TF2XLA ops into pure TF ops for future legalizations as +// TF2XLA ops aren't supported by later stages. +LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) { + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + + OwningRewritePatternList patterns; + mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns); + TF::PopulateLegalizeHloToTfPatterns(&patterns, context); + + return applyPartialConversion(func, target, patterns); +} + void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); @@ -711,6 +734,11 @@ void PrepareTFPass::runOnFunction() { return; } + if (failed(ConvertTf2XlaOps(func, ctx))) { + signalPassFailure(); + return; + } + // This pattern was intented to uses TFL QDQs to preserve the quantization // parameters from the TF Quant ops, thus this pattern should run with the // first `applyPatternsGreedily` method, which would otherwise removes the diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index 22bcc563f7b..38c754ed08c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -33,7 +33,7 @@ def : Pat<(TFL_QuantizeOp (TFL_DequantizeOp $in), $qt), (replaceWithValue $in)>; // point constant. def : Pat<(TFL_DequantizeOp (TFL_QuantizeOp (ConstantOp F32ElementsAttr:$cst), $qt)), - (ConstantOp $cst)>; + (TFL_ConstOp $cst)>; // Quantize the value of a constant op if the quantization parameters have been // propagated to the output. diff --git a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc new file mode 100644 index 00000000000..40cca526951 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFL { +namespace { +// This transformation pass takes an operation with unknown op properties and +// wrap it by a TFL::CustomTfOp. +struct RaiseCustomOpsPass + : public PassWrapper { + void runOnFunction() override; +}; + +void RaiseCustomOpsPass::runOnFunction() { + auto fn = getFunction(); + OpBuilder builder(fn.getContext()); + + llvm::SmallVector custom_ops; + for (Operation &op : fn.getOps()) { + // Skips the ops with known op property. + if (op.getAbstractOperation()) continue; + // Skips already imported ops that are imported as CustomTfOp. + if (op.getParentOfType()) continue; + if (llvm::isa(op) || llvm::isa(op)) + continue; + custom_ops.push_back(&op); + } + + for (auto *op : custom_ops) { + builder.setInsertionPoint(op); + auto custom_op = builder.create( + op->getLoc(), op->getResultTypes(), op->getOperands()); + Region region; + region.push_back(new Block); + + builder.setInsertionPointToEnd(®ion.front()); + Operation *inner_op = builder.clone(*op); + builder.create(op->getLoc(), inner_op->getResults()); + custom_op.body().takeBody(region); + + op->replaceAllUsesWith(custom_op); + op->erase(); + } +} +} // namespace + +// Creates an instance of the TensorFlow Lite dialect raise custom op pass. +std::unique_ptr> CreateRaiseCustomOpsPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tfl-raise-custom-ops", "Raise custom ops into tflite dialect."); + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index 106b0f9af83..56b38ec58d8 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -80,7 +80,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { // The basic block arguments correspond to values that are loop carried, while // all those post are loop independent. Initialize extern_values with while_op // not loop carried operands. - auto num_loop_carried = while_op.cond().front().getNumArguments(); + auto num_loop_carried = while_op.cond().getNumArguments(); auto not_carried_operands = while_op.getOperands().drop_front(num_loop_carried); extern_values.insert(not_carried_operands.begin(), @@ -124,8 +124,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { // Collect new types. SmallVector types; types.reserve(extra_operands.size() + while_op.getNumOperands()); - for (BlockArgument ba : while_op.cond().front().getArguments()) - types.push_back(ba.getType()); + for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type); for (Value operand : extern_values) types.push_back(operand.getType()); // Create outline function from region. 
Optional pass extra arguments through @@ -143,8 +142,7 @@ void WhileOutlinePass::OutlineWhile(WhileOp while_op) { type = FunctionType::get(types, result_types, &getContext()); } - auto outlined_func = builder.create(while_op.getLoc(), name, type, - ArrayRef{}); + auto outlined_func = builder.create(while_op.getLoc(), name, type); outlined_func.getBody().takeBody(region); Region& func_region = outlined_func.getBody(); diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 22283d7eace..6b3ad78a830 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -53,6 +53,8 @@ mlir::Type ConvertElementType(tflite::TensorType type, mlir::Builder builder) { return builder.getIntegerType(16); case tflite::TensorType_COMPLEX64: return mlir::ComplexType::get(builder.getF32Type()); + case tflite::TensorType_COMPLEX128: + return mlir::ComplexType::get(builder.getF64Type()); case tflite::TensorType_INT8: return builder.getIntegerType(8); } @@ -64,6 +66,8 @@ tensorflow::DataType TflTypeToTfType(tflite::TensorType type) { return tensorflow::DT_BOOL; case tflite::TensorType_COMPLEX64: return tensorflow::DT_COMPLEX64; + case tflite::TensorType_COMPLEX128: + return tensorflow::DT_COMPLEX128; case tflite::TensorType_FLOAT16: return tensorflow::DT_HALF; case tflite::TensorType_FLOAT32: diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 2f876c68fb8..3a469dd7341 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -134,7 +134,7 @@ Value SliceRankedTensor(OpBuilder* builder, Value input, // the input tensor's dimensions, return 0-valued tensor of the requested // shape. ArrayRef input_shape = GetRankedTensorShape(input); - for (int i = 0; i < input_shape.size(); i++) { + for (int i = 0, end = input_shape.size(); i < end; i++) { if (begin_values[i] < 0 || (begin_values[i] + size_values[i] > input_shape[i])) { return CreateF32SplatConst(builder, size_shape, 0, location); diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 2ed0891dc59..96d22cb51e9 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h" +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" #include "llvm/ADT/SmallVector.h" @@ -28,6 +29,7 @@ limitations under the License. 
#include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project @@ -43,30 +45,35 @@ namespace TFL { namespace { +constexpr char kNgrams[] = "tftext:Ngrams"; constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer"; -constexpr char kTFAPIImplements[] = "tf.api_implements"; +constexpr char kTFImplements[] = "tf._implements"; -inline OpaqueElementsAttr emptyCustomOption(OpBuilder* builder) { - std::string content = ""; +using mlir::TF::FuncAttr; +using mlir::TF::StringType; + +inline OpaqueElementsAttr CustomOption(OpBuilder* builder, + const std::string& content) { ShapedType type = RankedTensorType::get( {static_cast(content.size())}, builder->getIntegerType(8)); return OpaqueElementsAttr::get( - builder->getContext()->getRegisteredDialect("tfl"), type, content); + builder->getContext()->getRegisteredDialect("tfl"), type, + StringRef(content.data(), content.size())); } -inline RankedTensorType getInputType(mlir::FuncOp func, int idx) { - return func.getType() - .getInput(idx) - .dyn_cast_or_null(); +inline TensorType GetInputType(FuncOp func, int idx) { + return func.getType().getInput(idx).dyn_cast_or_null(); } -inline RankedTensorType getResultType(mlir::FuncOp func, int idx) { - return func.getType() - .getResult(idx) - .dyn_cast_or_null(); +inline TensorType GetResultType(FuncOp func, int idx) { + return func.getType().getResult(idx).dyn_cast_or_null(); } -LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { +inline bool RankEquals(const TensorType& type, int rank) { + return type && type.hasRank() && type.getRank() == rank; +} + +LogicalResult VerifyWhitespaceTokenizer(FuncOp func) { // In the case of input tensor with 0 rank. // Whitespace tokenizer generates 1 output: // * String tensor for tokens. @@ -81,8 +88,8 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { // * 1st output is the value of ragged tensor; // * 2nd output is the inner offset; // * 3rd output is the outer offset. 
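  // For example (illustrative), tokenizing the rank-1 input
  //   ["hello there", "hi"]
  // yields values ["hello", "there", "hi"] plus a row-offset tensor
  // [0, 2, 3]; each additional input rank adds one more offset output, which
  // is why the expected number of outputs below depends on the input rank.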
- auto input_type = getInputType(func, 0); - if (!input_type || !input_type.getElementType().isa() || + auto input_type = GetInputType(func, 0); + if (!input_type || !input_type.getElementType().isa() || !input_type.hasRank()) { return func.emitError() << "Input should be a string tensor"; } @@ -98,21 +105,21 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { << "output(s) when input has rank " << input_type.getRank(); } - auto value_type = getResultType(func, 0); - if (!value_type || !value_type.hasRank() || value_type.getRank() != 1 || - !value_type.getElementType().isa()) { + auto value_type = GetResultType(func, 0); + if (!RankEquals(value_type, 1) || + !value_type.getElementType().isa()) { return func.emitError() << "1st output should be string tensor"; } if (func.getNumResults() > 1) { - auto offset_type = getResultType(func, 1); - if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 || + auto offset_type = GetResultType(func, 1); + if (!RankEquals(offset_type, 1) || !offset_type.getElementType().isInteger(64)) { return func.emitError() << "2nd output should be int64 tensor"; } } if (func.getNumResults() > 2) { - auto offset_type = getResultType(func, 2); - if (!offset_type || !offset_type.hasRank() || offset_type.getRank() != 1 || + auto offset_type = GetResultType(func, 2); + if (!RankEquals(offset_type, 1) || !offset_type.getElementType().isInteger(64)) { return func.emitError() << "3rd output should be int64 tensor"; } @@ -121,36 +128,168 @@ LogicalResult VerifyWhitespaceTokenizer(mlir::FuncOp func) { return success(); } -LogicalResult ConvertWhitespaceTokenizer(mlir::FuncOp func, - llvm::StringRef api) { +LogicalResult ConvertWhitespaceTokenizer(FuncOp func, llvm::StringRef api, + FuncAttr attr) { func.eraseBody(); func.addEntryBlock(); - func.setAttr(kTFAPIImplements, StringAttr::get(api, func.getContext())); - Value text = func.getArgument(0); + func.setAttr(kTFImplements, attr); OpBuilder builder(func.getBody()); - - auto op = builder.create( - func.getLoc(), func.getType().getResults(), ValueRange(text), api, - emptyCustomOption(&builder)); - builder.create(func.getLoc(), op.getResults()); + std::string empty_option_buffer; + auto op = builder.create( + func.getLoc(), func.getType().getResults(), func.getArguments(), api, + CustomOption(&builder, empty_option_buffer)); + builder.create(func.getLoc(), op.getResults()); return success(); } + +LogicalResult VerifyNgrams(FuncOp func) { + // The inputs and outputs should be the same: + // * A string tensor for tokens/ragged tensor values. + // * Zero or more row_split tensors. 
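+  // For example (illustrative), an ngrams function over a ragged input may
+  // have a signature like
+  //   (tensor<?x!tf.string>, tensor<?xi64>)
+  //       -> (tensor<?x!tf.string>, tensor<?xi64>)
+  // i.e. a 1-D string values tensor plus one int64 row_splits tensor on both
+  // sides, while a dense input has only the string tensor and no row_splits.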
+ constexpr int kValues = 0; + constexpr int kRowSplits = 1; + + if (func.getType().getInputs().size() != func.getType().getResults().size()) { + return func.emitError() << "Mismatched number of inputs and outputs."; + } + + int row_splits = func.getType().getInputs().size() - kRowSplits; + if (row_splits == 0) { + auto input_values = GetInputType(func, kValues); + if (!input_values || !input_values.getElementType().isa()) { + return func.emitError() + << "Input " << kValues << " should be a string tensor"; + } + auto output_values = GetResultType(func, kValues); + if (!output_values || !output_values.getElementType().isa()) { + return func.emitError() + << "Output " << kValues << " should be a string tensor"; + } + + if (input_values.hasRank() && output_values.hasRank() && + input_values.getRank() != output_values.getRank()) { + return func.emitError() << "Input " << kValues << " and output " + << kValues << " should have the same rank"; + } + } else { + auto input_values = GetInputType(func, kValues); + if (!RankEquals(input_values, 1) || + !input_values.getElementType().isa()) { + return func.emitError() + << "Input " << kValues << " should be a 1D string tensor"; + } + auto output_values = GetResultType(func, kValues); + if (!RankEquals(output_values, 1) || + !output_values.getElementType().isa()) { + return func.emitError() + << "Output " << kValues << " should be a 1D string tensor"; + } + + for (int i = 0; i < row_splits; ++i) { + const int row_index = i + kRowSplits; + auto input_row_splits = GetInputType(func, row_index); + if (!RankEquals(input_row_splits, 1) || + !input_row_splits.getElementType().isInteger(64)) { + return func.emitError() + << "Input " << row_index << " should be a 1D int64 tensor"; + } + auto output_row_splits = GetResultType(func, row_index); + if (!RankEquals(output_row_splits, 1) || + !output_row_splits.getElementType().isInteger(64)) { + return func.emitError() + << "Output " << row_index << " should be a 1D int64 tensor"; + } + } + } + + return success(); +} + +LogicalResult CreateNgramsCustomOption(FuncOp func, DictionaryAttr attrs, + std::string& custom_option_buffer) { + flexbuffers::Builder fbb; + size_t start_map = fbb.StartMap(); + + auto width = attrs.get("width").dyn_cast_or_null(); + if (!width) { + return func.emitError() << "'width' attribute is not set or not an integer"; + } + fbb.Int("width", width.getInt()); + + auto string_separator = + attrs.get("string_separator").dyn_cast_or_null(); + if (!string_separator) { + return func.emitError() + << "'string_separator' attribute is not set or not a string"; + } + // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers + // strings expect NUL terminated strings. + std::string string_separator_str(string_separator.getValue().data(), + string_separator.getValue().size()); + fbb.String("string_separator", string_separator_str); + + auto axis = attrs.get("axis").dyn_cast_or_null(); + if (!axis) { + return func.emitError() << "'axis' attribute is not set or not an integer"; + } + fbb.Int("axis", axis.getInt()); + + auto reduction_type = + attrs.get("reduction_type").dyn_cast_or_null(); + if (!reduction_type) { + return func.emitError() + << "'reduction_type' attribute is not set or not a string"; + } + // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers + // strings expect NUL terminated strings. 
+ std::string reduction_type_str(reduction_type.getValue().data(), + reduction_type.getValue().size()); + fbb.String("reduction_type", reduction_type_str); + + fbb.EndMap(start_map); + fbb.Finish(); + custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end()); + return success(); +} + +LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) { + func.eraseBody(); + func.addEntryBlock(); + func.setAttr(kTFImplements, attr); + OpBuilder builder(func.getBody()); + std::string custom_option_buffer; + if (failed(CreateNgramsCustomOption(func, attr.GetAttrs(), + custom_option_buffer))) { + return failure(); + } + auto op = builder.create( + func.getLoc(), func.getType().getResults(), func.getArguments(), api, + CustomOption(&builder, custom_option_buffer)); + builder.create(func.getLoc(), op.getResults()); + return success(); +} + } // namespace -LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api) { +LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api, + FuncAttr attr) { if (api.str() == kWhitespaceTokenizer) { if (succeeded(VerifyWhitespaceTokenizer(func))) { - return ConvertWhitespaceTokenizer(func, api); + return ConvertWhitespaceTokenizer(func, api, attr); + } + } else if (api.str() == kNgrams) { + if (succeeded(VerifyNgrams(func))) { + return ConvertNgrams(func, api, attr); } } return failure(); } -bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery) { - const std::vector kTfTextOps = { +bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery) { + const std::vector kTFTextOps = { "WhitespaceTokenizeWithOffsets", }; - for (const auto& iter : kTfTextOps) { + for (const auto& iter : kTFTextOps) { if (op_registery->LookUp(iter)) { return true; } diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h index c52ee019d8d..55e4680c3dd 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.h @@ -27,14 +27,18 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/core/framework/op.h" namespace mlir { namespace TFL { -LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api); +// Fuse TF.Text APIs annotated by tf.function to a TFLite custom op. +LogicalResult ConvertTFTextAPI(mlir::FuncOp func, llvm::StringRef api, + mlir::TF::FuncAttr attr); -bool IsTfTextRegistered(const tensorflow::OpRegistry* op_registery); +// Check if TF.Text Tensorflow ops are registered. 
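+// For example, IsTFTextRegistered(tensorflow::OpRegistry::Global()) returns
+// true when the TF.Text kernels (such as WhitespaceTokenizeWithOffsets) are
+// linked into the binary, which enables the fusion even when the
+// command-line fusion flag is not set.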
+bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery); } // end namespace TFL } // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc index 7d29264aaae..9bcfa89c544 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -41,13 +41,13 @@ void Register(const std::string& op_name, OpRegistry* registry) { TEST(TfTextUtilsTest, TestTfTextRegistered) { std::unique_ptr registry(new OpRegistry); Register("WhitespaceTokenizeWithOffsets", registry.get()); - EXPECT_TRUE(IsTfTextRegistered(registry.get())); + EXPECT_TRUE(IsTFTextRegistered(registry.get())); } TEST(TfTextUtilsTest, TestTfTextNotRegistered) { std::unique_ptr registry(new OpRegistry); Register("Test", registry.get()); - EXPECT_FALSE(IsTfTextRegistered(registry.get())); + EXPECT_FALSE(IsTFTextRegistered(registry.get())); } } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 67002aa65bf..8be6facce38 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -115,13 +115,15 @@ Status MlirFunctionOptimizationPass::Run( }); if (!is_enabled) { - VLOG(0) << "None of the MLIR optimization passes are enabled " - << "(registered " << registry_->passes().size() << ")"; + LOG_FIRST_N(INFO, 1) + << "None of the MLIR optimization passes are enabled " + << "(registered " << registry_->passes().size() << ")"; return Status::OK(); } - VLOG(0) << "Running MLIR Graph Optimization Passes " - << "(registered " << registry_->passes().size() << " passes)"; + LOG_FIRST_N(INFO, 1) << "Running MLIR Graph Optimization Passes " + << "(registered " << registry_->passes().size() + << " passes)"; GraphDebugInfo debug_info; RegisterDialects(); @@ -130,6 +132,12 @@ Status MlirFunctionOptimizationPass::Run( import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; import_config.upgrade_legacy = true; + // Disable shape inference during import as some TensorFlow op fails during + // shape inference with dynamic shaped operands. This in turn causes the + // import to fail. Shape inference during import is going to be removed and + // the shape inference pass is run early in the pass pipeline, shape inference + // during import is not necessary. 
+ import_config.enable_shape_inference = false; TF_ASSIGN_OR_RETURN(auto module_ref, ConvertGraphToMlir(**graph, debug_info, *flib_def, import_config, &context)); @@ -187,13 +195,15 @@ Status MlirV1CompatGraphOptimizationPass::Run( }); if (!is_enabled) { - VLOG(0) << "None of the MLIR optimization passes are enabled " - << "(registered" << registry_->passes().size() << " passes)"; + LOG_FIRST_N(INFO, 1) + << "None of the MLIR optimization passes are enabled " + << "(registered " << registry_->passes().size() << " passes)"; return Status::OK(); } - VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes " - << "(registered" << registry_->passes().size() << " passes)"; + LOG_FIRST_N(INFO, 1) << "Running MLIR Graph Optimization V1 Compat Passes " + << "(registered " << registry_->passes().size() + << " passes)"; GraphDebugInfo debug_info; RegisterDialects(); diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index e3158f21cb2..45c8dce8422 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -73,7 +73,8 @@ tool_names = [ 'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', - 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir' + 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir', + 'kernel-gen-opt', 'xla-thunks-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/runlit.site.cfg.py b/tensorflow/compiler/mlir/runlit.site.cfg.py index 82175d7f680..b4d3e6185a6 100644 --- a/tensorflow/compiler/mlir/runlit.site.cfg.py +++ b/tensorflow/compiler/mlir/runlit.site.cfg.py @@ -47,6 +47,7 @@ mlir_tf_tools_dirs = [ 'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/xla', + 'tensorflow/compiler/mlir/tools/kernel_gen', 'tensorflow/compiler/aot', 'tensorflow/compiler/xla/service/mlir_gpu', 'tensorflow/compiler/xla/service/gpu/tests', diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 14d7faecdca..d2e57f72774 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -88,6 +88,7 @@ gentbl( cc_library( name = "tensorflow_op_interfaces", srcs = [ + "ir/tf_op_interfaces.cc", "ir/tf_op_interfaces.cc.inc", "ir/tf_op_interfaces.h.inc", "ir/tf_verifiers.cc", @@ -105,15 +106,67 @@ cc_library( ) gentbl( - name = "tensorflow_ops_inc_gen", + name = "tensorflow_all_ops_inc_gen", tbl_outs = [ ( "-gen-op-decls", - "ir/tf_ops.h.inc", + "ir/tf_all_ops.h.inc", ), ( "-gen-op-defs", - "ir/tf_ops.cc.inc", + "ir/tf_all_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tf_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], +) + +# We only shard tf_op on name for build performance reasons. 
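+# For example, with the two shards below, tf.Abs and tf.MatMul land in
+# tensorflow_ops_a_m while tf.Relu and tf.Sum land in tensorflow_ops_n_z;
+# each shard gets its own gentbl output and cc_library so the generated op
+# definitions compile separately instead of as one monolithic target.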
+tf_ops_category_list = [ + { + "name": "ops_a_m", + "include": "tf.[A-M].*$$", + }, + { + "name": "ops_n_z", + "include": "tf.[N-Z].*$$", + }, +] + +[[ + gentbl( + name = "tensorflow_" + target["name"] + "_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls -op-include-regex='" + target["include"] + "'", + "ir/tf_" + target["name"] + ".h.inc", + ), + ( + "-gen-op-defs -op-include-regex='" + target["include"] + "'", + "ir/tf_" + target["name"] + ".cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/tf_ops.td", + td_srcs = [ + ":tensorflow_ops_td_files", + ], + ), +] for target in tf_ops_category_list] + +gentbl( + name = "tensorflow_remaining_ops_inc_gen", + tbl_outs = [ + ( + "-gen-op-decls -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "ir/tf_remaining_ops.h.inc", + ), + ( + "-gen-op-defs -op-exclude-regex='" + "|".join([target["include"] for target in tf_ops_category_list]) + "' ", + "ir/tf_remaining_ops.cc.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", @@ -179,7 +232,7 @@ gentbl( name = "tensorflow_device_ops_inc_gen", tbl_outs = [ ( - "-gen-op-decls", + "-gen-op-decls ", "ir/tf_device.h.inc", ), ( @@ -280,28 +333,72 @@ cc_library( deps = [ ":tensorflow_types", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffects", "@llvm-project//mlir:Support", ], ) +[[ + cc_library( + name = "tensorflow_" + target["name"], + srcs = [ + "ir/tf_ops.h", + "ir/tf_remaining_ops.h", + "ir/tf_" + target["name"] + ".cc", + "ir/tf_" + target["name"] + ".cc.inc", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + hdrs = [ + ], + textual_hdrs = [ + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tf_remaining_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], + deps = [ + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_traits", + ":tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ] + [":tensorflow_" + target["name"] + "_inc_gen"], + ), +] for target in tf_ops_category_list] + cc_library( - name = "tensorflow_ops", + name = "tensorflow_remaining_ops", srcs = [ - "ir/tf_ops.cc", - "ir/tf_ops.cc.inc", "ir/tf_ops.h", - ], + "ir/tf_remaining_ops.h", + "ir/tf_remaining_ops.cc", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], hdrs = [ ], textual_hdrs = [ - "ir/tf_ops.h.inc", - ], + "ir/tf_all_ops.h.inc", + "ir/tf_ops_helpers.inc", + "ir/tf_remaining_ops.h.inc", + ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", - ":tensorflow_ops_inc_gen", + ":tensorflow_remaining_ops_inc_gen", ":tensorflow_side_effects", ":tensorflow_structs", ":tensorflow_traits", @@ -321,6 +418,43 @@ cc_library( ], ) +cc_library( + name = "tensorflow_ops", + srcs = [ + "ir/tf_ops.cc", + "ir/tf_ops.h", + ], + textual_hdrs = [ + 
"ir/tf_all_ops.h.inc", + "ir/tf_remaining_ops.h", + ] + ["ir/tf_" + target["name"] + ".h" for target in tf_ops_category_list], + deps = [ + ":tensorflow_all_ops_inc_gen", + ":tensorflow_remaining_ops_inc_gen", + ":tensorflow_attributes", + ":tensorflow_canonicalize_inc_gen", + ":tensorflow_op_interfaces", + ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_side_effects", + ":tensorflow_structs", + ":tensorflow_traits", + ":tensorflow_types", + ":tensorflow_remaining_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:SideEffects", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ] + [":tensorflow_" + target["name"] for target in tf_ops_category_list], +) + cc_library( name = "tensorflow_structs", srcs = [ @@ -393,12 +527,14 @@ cc_library( includes = ["include"], deps = [ ":error_util", + ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_device_ops_inc_gen", ":tensorflow_executor_inc_gen", ":tensorflow_op_interfaces", ":tensorflow_ops", + ":tensorflow_side_effects", ":tensorflow_structs", ":tensorflow_traits", ":tensorflow_types", @@ -540,6 +676,7 @@ cc_library( cc_library( name = "tf_saved_model_passes", srcs = [ + "transforms/deduplicate_bound_input_bindings.cc", "transforms/freeze_global_tensors.cc", "transforms/lift_variables_pass.cc", "transforms/optimize_global_tensors.cc", @@ -567,6 +704,30 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "tensorflow_analysis", + srcs = [ + "analysis/per_function_aggregate_analysis.h", + "analysis/resource_alias_analysis.cc", + "analysis/side_effect_analysis.cc", + ], + hdrs = [ + "analysis/resource_alias_analysis.h", + "analysis/side_effect_analysis.h", + ], + deps = [ + ":tensorflow", + ":tensorflow_types", + "//tensorflow/compiler/tf2xla:resource_operation_table", + "//tensorflow/core:framework", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tensorflow_passes", srcs = [ @@ -592,11 +753,15 @@ cc_library( "transforms/generated_optimize.inc", "transforms/gpu_fusion.cc", "transforms/graph_pruning.cc", + "transforms/guarantee_all_funcs_one_use.cc", + "transforms/init_text_file_to_import.cc", "transforms/launch_to_device_attribute.cc", "transforms/layout_optimization.cc", + "transforms/mark_ops_for_outside_compilation.cc", "transforms/materialize_mlir_passthrough_op.cc", "transforms/optimize.cc", "transforms/parallel_execute_to_islands.cc", + "transforms/parallelize_embedding_params_ops_pass.cc", "transforms/promote_resources_to_args.cc", "transforms/readonly_references_to_resources.cc", "transforms/region_control_flow_to_functional.cc", @@ -611,6 +776,7 @@ cc_library( "transforms/stack_ops_decomposition.cc", "transforms/tensor_array_ops_decomposition.cc", "transforms/tensor_list_ops_decomposition.cc", + "transforms/test_resource_alias_analysis.cc", "transforms/test_side_effect_analysis.cc", "transforms/tf_data_optimization_pass.cc", "transforms/tf_device_assignment.cc", @@ -632,6 +798,7 @@ cc_library( "translate/tf_functional_to_executor.cc", ], hdrs = [ + "transforms/attribute_utils.h", "transforms/batchmatmul_to_einsum.h", 
"transforms/bridge.h", "transforms/collection_ops_util.h", @@ -650,8 +817,8 @@ cc_library( ":error_util", ":export_tf_dialect_op", ":mangling_util", - ":side_effect_analysis", ":tensorflow", + ":tensorflow_analysis", ":tensorflow_optimize_inc_gen", ":tensorflow_types", ":tf_data_optimization", @@ -661,6 +828,8 @@ cc_library( ":xla_sharding_util", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:validators", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_proto_cc", "//tensorflow/compiler/xla/client:sharding_builder", @@ -671,6 +840,7 @@ cc_library( "//tensorflow/core/platform:random", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", + "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -690,6 +860,7 @@ cc_library( cc_library( name = "tensorflow_test_passes", srcs = [ + "transforms/init_text_file_to_import_test_pass.cc", "transforms/lift_variables_test_pass.cc", "transforms/lower_tf_pass.cc", ], @@ -705,8 +876,10 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/platform:threadpool_options", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", ], alwayslink = 1, @@ -1190,11 +1363,13 @@ cc_library( ":mlir_roundtrip_flags", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/utils:transitive_fanin", "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1315,6 +1490,7 @@ COMPILE_MLIR_UTIL_DEPS = [ ":mlir_roundtrip_flags", ":tensorflow", ":tensorflow_dialect_registration", + ":tensorflow_types", ":tensorflow_passes", ":translate_utils", "@com_google_absl//absl/types:optional", @@ -1333,10 +1509,13 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:logging", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", @@ -1373,6 +1552,9 @@ tf_cc_test( srcs = ["utils/compile_mlir_util_test.cc"], deps = [ ":compile_mlir_util", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/cc:scope", "//tensorflow/compiler/jit", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", @@ -1500,6 +1682,7 @@ cc_library( ":tensorflow", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", + "@com_google_absl//absl/strings", 
"@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1602,22 +1785,6 @@ cc_library( ], ) -cc_library( - name = "side_effect_analysis", - srcs = ["analysis/side_effect_analysis.cc"], - hdrs = ["analysis/side_effect_analysis.h"], - deps = [ - ":tensorflow", - ":tensorflow_types", - "//tensorflow/compiler/tf2xla:resource_operation_table", - "//tensorflow/core:framework", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) - cc_library( name = "xla_sharding_util", srcs = [ diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h new file mode 100644 index 00000000000..da7a2bd9b5c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h @@ -0,0 +1,76 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ + +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace TF { +namespace detail { + +// This template defines an aggregate analysis base class, which analyzes a +// module but the analysis info is stored per function. +template +class PerFunctionAggregateAnalysis { + public: + using Info = InfoT; + + // Returns the analysis info for the given function. + const Info& GetAnalysisForFunc(FuncOp func) const { + auto it = info_map_.find(func); + assert(it != info_map_.end()); + return it->second; + } + + protected: + llvm::SmallDenseMap info_map_; +}; + +} // namespace detail + +// Base CRTP class to help write passes that are consumes a per-function +// aggregate analysis and operate on all non-extern functions (similar to a +// FunctionPass, but with no concurrency between functions). The derived classes +// need to provide a runOnFunction() method that accepts the function and the +// analysis information for that function. 
+template +class PerFunctionAggregateAnalysisConsumerPass + : public PassWrapper< + PerFunctionAggregateAnalysisConsumerPass, + OperationPass> { + void runOnOperation() override { + ModuleOp op = this->getOperation(); + DerivedT& derived = *static_cast(this); + auto& analysis = this->template getAnalysis(); + + for (auto func : op.getOps()) + if (!func.isExternal()) + derived.runOnFunction(func, analysis.GetAnalysisForFunc(func)); + } +}; + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_PER_FUNCTION_AGGREGATE_ANALYSIS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc new file mode 100644 index 00000000000..256217b6542 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -0,0 +1,507 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" + +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "tensorflow/core/framework/resource_mgr.h" + +namespace mlir { +namespace TF { +namespace detail { + +//===----------------------------------------------------------------------===// +// BacktrackAnalysisInfo +//===----------------------------------------------------------------------===// +// Class to hold backtrack analysis for a results of a region. Backtrack +// analysis will trace back the definition of return values of regions through +// pass-through operations, so that the return value of the region will have the +// same value as the backtracked value. +class BacktrackAnalysisInfo { + public: + // Initializes the backtrack analysis for the given region. 
+ explicit BacktrackAnalysisInfo(Region& region, + detail::BacktrackAnalysis& backtrack_analysis); + + BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default; + + // Returns the value to which the given result number of the region can be + // backtracked to. + Value GetValue(int result_index) const { + return backtracked_values_[result_index]; + } + + // Returns the argument index of the region to which the given result number + // can backtracked to. Such results will be called "function passthrough". If + // the result cannot be backtracked to a region argument, returns llvm::None. + llvm::Optional GetArg(int result_index) const { + if (auto arg = GetValue(result_index).dyn_cast()) + if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber(); + return llvm::None; + } + + private: + friend class detail::BacktrackAnalysis; + + // Region for which this object holds the analysis info. + Region* region_; + + // Backtracked values indexed by the result number. + llvm::SmallVector backtracked_values_; +}; + +//===----------------------------------------------------------------------===// +// BacktrackAnalysis +//===----------------------------------------------------------------------===// +// Holds backtrack analysis for all functions and regions within a module. +class BacktrackAnalysis { + public: + using InfoT = BacktrackAnalysisInfo; + + // Constructs the analysis by analyzing the given module. + explicit BacktrackAnalysis(ModuleOp module); + + // Returns backtracking analysis for the given region. + const InfoT& GetAnalysisForRegion(Region& region) const { + auto it = info_map_.find(®ion); + assert(it != info_map_.end()); + return it->second; + } + + // Returns backtracking analysis for the given function. + const InfoT& GetAnalysisForFunc(FuncOp func) const { + return GetAnalysisForRegion(func.getBody()); + } + + // Backtracks the given value. + Value BacktrackValue(Value value); + + private: + // Returns the analysis for the given region (analyzing the region if it has + // not yet been analyzed). + const InfoT& GetOrCreateAnalysis(Region& region) { + auto it = info_map_.find(®ion); + if (it == info_map_.end()) { + // Note: Keep object construction and insertion separate. If we use + // emplace() to construct and insert in a single shot, when analyzing + // this region, calls to BacktrackValue() may end up inserting additional + // entries in the map, causing the underlying storage to be moved. This + // would also include this pertially constructed object that we have just + // inserted into the map and are constructing it. To avoid this issue, + // construct the analysis object separately and then insert it into the + // map. + InfoT info(region, *this); + info_map_.insert({®ion, std::move(info)}); + } + + return GetAnalysisForRegion(region); + } + + private: + llvm::SmallDenseMap info_map_; +}; + +// Analyzes all regions attached to all operations in the module. +BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) { + module.walk([this](Operation* op) { + for (Region& region : op->getRegions()) GetOrCreateAnalysis(region); + }); +} + +// Backtracks the definition of `value` looking through passthrough ops. +// Returns a non-null value and can return `value` if backtracking is not +// possible. 
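A short, hypothetical example of what the backtracking resolves (shown as
comments; the IR described is schematic and not taken from a test):

// Example (schematic): suppose a tf_executor.graph fetches the data result of
// an island whose yield returns the result of a "tf.Identity"(%v). Then
//   Value src = backtrack_analysis.BacktrackValue(graph_op.getResult(0));
// walks fetch -> yield -> Identity and yields %v. If the defining op is not a
// graph, an island, or an Identity-style passthrough op, the original value is
// returned as-is.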
+Value BacktrackAnalysis::BacktrackValue(Value value) { + while (Operation* op = value.getDefiningOp()) { + int res_index = value.cast().getResultNumber(); + if (auto graph = dyn_cast(op)) { + value = graph.GetFetch().getOperand(res_index); + } else if (auto island = dyn_cast(op)) { + // Control output is generated by the IslandOp, not the yield in + // in the Island body. + if (value == island.control()) break; + value = island.GetYield().getOperand(res_index); + } else if (isa(op)) { + value = op->getOperand(res_index); + } else { + break; + } + } + return value; +} + +// Analyze the region. +BacktrackAnalysisInfo::BacktrackAnalysisInfo( + Region& region, detail::BacktrackAnalysis& backtrack_analysis) + : region_(®ion) { + if (region.empty()) return; + + assert(llvm::hasSingleElement(region.getBlocks())); + auto results = region.front().getTerminator()->getOperands(); + if (results.empty()) return; + + backtracked_values_.reserve(results.size()); + for (auto result : results) + backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result)); +} + +namespace { + +//===----------------------------------------------------------------------===// +// ResourceAliasAnalysisInfo helper functions. +//===----------------------------------------------------------------------===// + +constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; + +// Returns if a VarHandleOp is anonymous, which means it always creates a new +// variable. +bool IsResourceHandleAnonymous(VarHandleOp handle) { + return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; +} + +// Returns a string unique identifier for a non-anonymous VarHandleOp. +std::string GetVarHandleStringId(VarHandleOp handle) { + auto device = handle.getAttrOfType("device"); + return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), + "/", device ? device.getValue().str() : std::string("")); +} + +// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always +// creates a new ID; otherwise, tries to reuse the existing ID for the +// referenced variable if it exists, or creates a new one if not. +int64_t GetOrCreateIdForVarHandle(VarHandleOp handle, int64_t* next_id, + llvm::StringMap* name_id_map) { + // Always create a new ID for anonymous handle. + if (IsResourceHandleAnonymous(handle)) return (*next_id)++; + + auto name = GetVarHandleStringId(handle); + auto emplace_res = name_id_map->try_emplace(name, *next_id); + // New ID created, increment next_id. + if (emplace_res.second) ++(*next_id); + return emplace_res.first->second; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// ResourceAliasAnalysisInfo +//===----------------------------------------------------------------------===// + +// Constructs the analysis info by analyzing the given function. +ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( + FuncOp func_op, const BacktrackAnalysis& backtrack_analysis) { + // This function populates resource_value_to_ids_ and id_to_resource_values_. + + int64_t next_unique_id = 0; + + // Helper to assign new unique id for all resources in the given list of + // values. + auto assign_unique_id_to_all = [&](ValueRange values) { + for (Value value : filter_resources(values)) { + AddValueUniqueIDMapping(value, next_unique_id++); + } + }; + + // Helper to assign new unknown id for all resources in the given list of + // values. 
+ auto assign_unknown_id_to_all = [&](ValueRange values) { + for (Value value : filter_resources(values)) { + AddValueUniqueIDMapping(value, kUnknownResourceId); + } + }; + + // If the "tf.resource_arg_unique_id" argument attributes are present for + // resource-type arguments, respect them when choosing IDs; otherwise, they + // must not alias. + const bool has_arg_unique_id_attrs = + llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) { + return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr); + }); + // Maps the kResourceArgUniqueIdAttr attribute value to the internal integer + // ID used by this pass. + if (has_arg_unique_id_attrs) { + llvm::SmallDenseMap attr_id_to_internal_id; + for (auto arg : filter_resources(func_op.getArguments())) { + auto id_attr = func_op.getArgAttrOfType( + arg.getArgNumber(), kResourceArgUniqueIdAttr); + assert(id_attr && + "tf.resource_arg_unique_id attribute should exist on either " + "none or all arguments."); + auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(), + next_unique_id++); + AddValueUniqueIDMapping(arg, emplace_res.first->getSecond()); + } + } else { + assign_unique_id_to_all(func_op.getArguments()); + } + + // Since this analysis is neither inter-procedural nor inter-regional, + // each region attached to Op's within a function is analyzed independently. + // Seed this analysis for each such region by mapping all resource arguments + // for such regions to a new unique-id. This is required because walk() walks + // the attached regions first before visiting the op, so there is no + // opportunity during the walk to seed region arguments. Also note that walk + // eventually also visits the Op on which the walk() is called, so make sure + // we do not overwrite the function argument mapping here. + func_op.walk([&](Operation* op) { + if (op == func_op) return; + for (Region& region : op->getRegions()) { + assign_unique_id_to_all(region.getArguments()); + } + }); + + llvm::StringMap var_handle_name_id_map; + func_op.walk([&](Operation* op) { + if (auto var_handle = dyn_cast(op)) { + AddValueUniqueIDMapping( + var_handle.resource(), + GetOrCreateIdForVarHandle(var_handle, &next_unique_id, + &var_handle_name_id_map)); + } else if (llvm::isa(op)) { + for (auto result : filter_resources(op->getResults())) + PropagateInputToOutput(op->getOperand(result.getResultNumber()), + result); + } else if (auto while_op = dyn_cast(op)) { + AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc( + while_op.body_func())); + } else if (auto while_region = dyn_cast(op)) { + AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion( + while_region.body())); + } else if (auto if_op = dyn_cast(op)) { + const auto& then_info = + backtrack_analysis.GetAnalysisForFunc(if_op.then_func()); + const auto& else_info = + backtrack_analysis.GetAnalysisForFunc(if_op.else_func()); + // If a result is a passthrough of both branches' inputs, merge the + // resource IDs of corresponding operands for the two inputs. 
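A concrete (hypothetical) case of the merge performed in the loop below, with
invented symbol names and result types omitted:

// Given   %r = "tf.If"(%cond, %v0, %v1) { then_branch = @t, else_branch = @e }
// where @t returns its resource argument #0 and @e returns its resource
// argument #1, the result %r receives the union of the IDs assigned to %v0 and
// %v1, i.e. %r may alias either input. If either branch result cannot be
// backtracked to a function argument, %r is mapped to kUnknownResourceId.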
+ for (auto result : filter_resources(if_op.getResults())) { + auto passthrough_then_arg = then_info.GetArg(result.getResultNumber()); + auto passthrough_else_arg = else_info.GetArg(result.getResultNumber()); + if (passthrough_then_arg && passthrough_else_arg) { + Value then_operand = if_op.input()[passthrough_then_arg.getValue()]; + Value else_operand = if_op.input()[passthrough_else_arg.getValue()]; + PropagateInputToOutput(then_operand, result); + PropagateInputToOutput(else_operand, result); + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } + } else if (auto if_region = dyn_cast(op)) { + const auto& then_info = + backtrack_analysis.GetAnalysisForRegion(if_region.then_branch()); + const auto& else_info = + backtrack_analysis.GetAnalysisForRegion(if_region.else_branch()); + for (auto result : filter_resources(if_region.getResults())) { + Value then_result = then_info.GetValue(result.getResultNumber()); + Value else_result = else_info.GetValue(result.getResultNumber()); + // For IfRegion, the walk would have visited the else and then regions + // before visiting the IfRegion op. Backtracking of the then and else + // results will either give a value computed within these regions, + // or a region capture. If its a region capture, computed before this + // IfRegion, it will have been visited earlier and a mapping would + // exist for that value. If its computed within the region, then again + // a mapping would exist. + PropagateInputToOutput(then_result, result); + PropagateInputToOutput(else_result, result); + } + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) { + assign_unknown_id_to_all(op->getResults()); + return WalkResult::advance(); + } + const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func); + for (auto result : filter_resources(op->getResults())) { + auto passthrough_arg = func_info.GetArg(result.getResultNumber()); + if (passthrough_arg) { + PropagateInputToOutput( + call.getArgOperands()[passthrough_arg.getValue()], result); + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } + } else { + assign_unknown_id_to_all(op->getResults()); + } + return WalkResult::advance(); + }); +} + +// Propagates the resource ID's from an input operand to a result. Returns true +// if the mapping changed. +bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand, + const OpResult& result) { + auto operand_it = resource_value_to_ids_.find(operand); + assert(operand_it != resource_value_to_ids_.end() && + "A resource-type output does not have the corresponding " + "resource-type input."); + bool change = false; + for (int64_t id : operand_it->second) + change = AddValueUniqueIDMapping(result, id) || change; + return change; +} + +// Analyzes while loops to compute resourceIDs for the loop results. +// +// (1) The base case for the analysis is that if the loop body does not execute +// at all, the resource IDs for each result is the same as the resource IDs +// of the corresponding input. +// (2) If the loop does execute one or more times, then we need to account for +// data flow through the body of the while loop. If result #r is the same +// as arg #a of the loop body (pass through argument), then we can reason +// further, else if the result is not a passthrough, we mark it as unknown. 
+// (3) For passthrough results, if result #r is the same as arg #a of the loop +// body, after one iteration, result #r = arg #a, so we need to also +// propagate arg #a to result #r. After another iteration, arg #a of the +// loop body will be result #a of the previous iteration. So then we need +// propagate from result #a to result #r. Generalizing, the resource ID +// propagation (for results which are passthrough) looks like: +// +// for r in (0, num_results) : result[r] = arg[r]; +// repeat till no change { +// a = passthrough arg for result #r; +// result[r] += result[a]; +// } +// +void ResourceAliasAnalysisInfo::AnalyzeWhileLoop( + Operation* while_op, const BacktrackAnalysisInfo& body_info) { + // Seed the resource ID's for the results using either the resource ID of the + // passthrough arg, or unknown. We need to perform further analysis if we + // find a passthrough arg which is not the same as corresponding the result #. + llvm::SmallVector, 4> passthrough_args( + while_op->getNumResults()); + bool need_analysis = false; + for (auto result : filter_resources(while_op->getResults())) { + int result_index = result.getResultNumber(); + passthrough_args[result_index] = body_info.GetArg(result_index); + if (passthrough_args[result_index]) { + int passthru_index = passthrough_args[result_index].getValue(); + PropagateInputToOutput(while_op->getOperand(passthru_index), result); + need_analysis |= + !IsUnknownResource(result) && passthru_index != result_index; + } else { + AddValueUniqueIDMapping(result, kUnknownResourceId); + } + } + + if (!need_analysis) return; + + // We found a result that is not unknown and whose passthrough operand index + // is not the same as the result index, which means there is "crosstalk" + // between 2 or more operands. In that case, we do an iterative propagation + // of resource ID's till the results converge. + bool change = true; + while (change) { + change = false; + for (auto result : filter_resources(while_op->getResults())) { + if (IsUnknownResource(result)) continue; + // If this result has a valid passthrough arg, propagate resource ID's + // from the result of the passthrough arg + int result_index = result.getResultNumber(); + int passthru_index = passthrough_args[result_index].getValue(); + change = + PropagateInputToOutput(while_op->getResult(passthru_index), result) || + change; + } + } +} + +bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const { + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); + // The set is sorted so we only need to check the first element since + // kUnknownResourceId < 0. 
+ static_assert(kUnknownResourceId < 0, + "kUnknownResourceId should be negative"); + return *it->getSecond().begin() == kUnknownResourceId; +} + +const llvm::SmallSet& +ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const { + assert(!IsUnknownResource(resource)); + auto it = resource_value_to_ids_.find(resource); + assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); + return it->getSecond(); +} + +const llvm::SmallSetVector& +ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const { + auto it = id_to_resource_values_.find(id); + assert(it != id_to_resource_values_.end() && "Unseen id was queried"); + return it->getSecond(); +} + +llvm::SmallSetVector ResourceAliasAnalysisInfo::GetResourceAliases( + Value resource) const { + assert(!IsUnknownResource(resource) && "Unknown resource was queried"); + llvm::SmallSetVector aliases; + for (int64_t id : GetResourceUniqueIds(resource)) { + const llvm::SmallSetVector& resources_aliasing_id = + GetUniqueIdResources(id); + aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end()); + } + // If there are resources that were marked as unknown, they alias with all + // other resources. + auto it = id_to_resource_values_.find(kUnknownResourceId); + if (it != id_to_resource_values_.end()) + aliases.insert(it->getSecond().begin(), it->getSecond().end()); + return aliases; +} + +} // namespace detail + +//===----------------------------------------------------------------------===// +// ResourceAliasAnalysis +//===----------------------------------------------------------------------===// + +ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { + auto module = dyn_cast(op); + assert(module); + + // Analyze all regions for backtracking info. + detail::BacktrackAnalysis backtrack_analysis(module); + + // Analyze each function. + for (auto func : module.getOps()) + this->info_map_.try_emplace(func, func, backtrack_analysis); +} + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h new file mode 100644 index 00000000000..c965b5d7602 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -0,0 +1,120 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/per_function_aggregate_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace TF { +namespace detail { +class BacktrackAnalysis; +class BacktrackAnalysisInfo; + +// Resource alias analysis information for a single function. +class ResourceAliasAnalysisInfo { + public: + // Constructs analysis info by analyzing the given function. + ResourceAliasAnalysisInfo(FuncOp func, + const BacktrackAnalysis& backtrack_analysis); + + ResourceAliasAnalysisInfo(ResourceAliasAnalysisInfo&&) = default; + + // Returns if the analysis fails to resolve a resource-type value. + bool IsUnknownResource(Value resource) const; + + // Returns the set unique IDs which `resource` could alias. Requires that + // IsUnknownResource(resource) == false. + const llvm::SmallSet& GetResourceUniqueIds(Value resource) const; + + // Returns the set of values that are potentially aliases of `value`. Requires + // that IsUnknownResource(resource) == false. + llvm::SmallSetVector GetResourceAliases(Value resource) const; + + private: + // Maps resource value to unique ID and vice-versa. Returns true of the + // mapping has changed. + bool AddValueUniqueIDMapping(Value value, int64_t id) { + resource_value_to_ids_[value].insert(id); + return id_to_resource_values_[id].insert(value); + } + + // Returns the set unique Values which map to `id`. + const llvm::SmallSetVector& GetUniqueIdResources(int64_t id) const; + + // Propagates the resource ID's from an input operand to a result. Returns + // true of the mapping has changed. + bool PropagateInputToOutput(const Value& operand, const OpResult& result); + + // Analyzes while loops to compute resourceID's for the loop results. + // `body_info` is the backtrack analysis info for the loop body. + void AnalyzeWhileLoop(Operation* while_op, + const BacktrackAnalysisInfo& body_info); + + // Maps each resource-type value to a set of unique IDs that it could alias. + llvm::SmallDenseMap, 8> + resource_value_to_ids_; + + // Maps each unique ID to a set of resource-type values that could alias to + // it. This is inverse of `resource_value_to_ids_` map. + llvm::SmallDenseMap, 8> + id_to_resource_values_; + + public: + static constexpr int64_t kUnknownResourceId = -1; +}; + +} // namespace detail + +// An analysis that runs on a module and maps each resource-type value to a +// set of unique IDs representing the possible resources it could alias. +// +// Note that this is not an inter-procedural or inter-regional analysis, i.e., +// each function and region are handled separately and cross-function or cross- +// region aliasing cannot be checked by this analysis. +class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis< + detail::ResourceAliasAnalysisInfo> { + public: + // Constructs analysis by analyzing the given module operation. 
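A hedged usage sketch may help orient readers (the function name is invented,
`using namespace mlir` is assumed, and filter_resources is the helper declared
further below in this header):

// Hypothetical: build the analysis once per module, then query per function.
void ExampleAliasQuery(ModuleOp module) {
  TF::ResourceAliasAnalysis analysis(module);
  for (FuncOp func : module.getOps<FuncOp>()) {
    const auto& info = analysis.GetAnalysisForFunc(func);
    for (Value resource : TF::filter_resources(func.getArguments())) {
      if (info.IsUnknownResource(resource)) continue;
      // Values that may refer to the same underlying resource as `resource`.
      auto aliases = info.GetResourceAliases(resource);
      (void)aliases;
    }
  }
}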
+ explicit ResourceAliasAnalysis(Operation* op); +}; + +// Returns a range with just resource type values from the input range +// preserved. +template +auto filter_resources(RangeT&& range) { + return llvm::make_filter_range(std::forward(range), [](Value val) { + return getElementTypeOrSelf(val.getType()).isa(); + }); +} + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_RESOURCE_ALIAS_ANALYSIS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index be203e0397e..9e78b90debc 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -45,234 +45,14 @@ limitations under the License. namespace mlir { namespace TF { - namespace { -constexpr int64_t kUnknownResourceId = -1; -constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id"; +constexpr auto kUnknownResourceId = + ResourceAliasAnalysis::Info::kUnknownResourceId; -// Returns if a VarHandleOp is anonymous, which means it always creates a new -// variable. -bool IsResourceHandleAnonymous(TF::VarHandleOp handle) { - return handle.shared_name() == tensorflow::ResourceHandle::ANONYMOUS_NAME; -} - -// Returns a string unique identifier for a non-anonymous VarHandleOp. -std::string GetVarHandleStringId(TF::VarHandleOp handle) { - auto device = handle.getAttrOfType("device"); - return absl::StrCat(handle.container().str(), "/", handle.shared_name().str(), - "/", device ? device.getValue().str() : std::string("")); -} - -// Finds a unique ID for a VarHandleOp's output. If it is anonymous, always -// creates a new ID; otherwise, tries to reuse the existing ID for the -// referenced variable if it exists, or creates a new one if not. -int64_t GetOrCreateIdForVarHandle(TF::VarHandleOp handle, int64_t* next_id, - llvm::StringMap* name_id_map) { - // Always create a new ID for anonymous handle. - if (IsResourceHandleAnonymous(handle)) return (*next_id)++; - - auto name = GetVarHandleStringId(handle); - auto emplace_res = name_id_map->try_emplace(name, *next_id); - // New ID created, increment next_id. - if (emplace_res.second) ++(*next_id); - return emplace_res.first->second; -} - -// If the return value for `func_op` at `return_index` is a pass-through of an -// argument of this function, returns the argument index; otherwise, returns -1. 
-int64_t FindPassthroughArgumentForReturnValue(int64_t return_index, - FuncOp func_op) { - auto value = - func_op.getBody().front().getTerminator()->getOperand(return_index); - assert(mlir::getElementTypeOrSelf(value.getType()).isa()); - int64_t arg_index = -1; - auto try_parse_arg_index = [&arg_index](Value v) { - auto resource_arg = v.dyn_cast(); - if (resource_arg) arg_index = resource_arg.getArgNumber(); - return arg_index; - }; - while (try_parse_arg_index(value) == -1) { - auto op = value.getDefiningOp(); - assert(op); - int64_t res_num = value.cast().getResultNumber(); - if (auto graph = llvm::dyn_cast(op)) { - value = graph.GetFetch().getOperand(res_num); - } else if (auto island = llvm::dyn_cast(op)) { - value = island.GetYield().getOperand(res_num); - } else if (llvm::isa(op)) { - value = op->getOperand(res_num); - } else { - return -1; - } - } - return arg_index; -} - -} // namespace - -ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { - auto func_op = llvm::dyn_cast(op); - if (!func_op) return; - AnalyzeFunction(func_op); -} - -void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { - // This function populates resource_value_to_ids_ and id_to_resource_values_. - - // If the "tf.resource_arg_unique_id" argument attributes are present for - // resource-type arguments, respect them when choosing IDs; otherwise, they - // must not alias. - int64_t next_unique_id = 0; - const bool has_arg_unique_id_attrs = - llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) { - return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr); - }); - // Maps the kResourceArgUniqueIdAttr attribute value to the internal integer - // ID used by this pass. - llvm::SmallDenseMap attr_id_to_internal_id; - for (auto arg : func_op.getArguments()) { - if (!mlir::getElementTypeOrSelf(arg.getType()).isa()) - continue; - if (has_arg_unique_id_attrs) { - auto id_attr = func_op.getArgAttrOfType( - arg.getArgNumber(), kResourceArgUniqueIdAttr); - assert(id_attr && - "tf.resource_arg_unique_id attribute should exist on either none " - "or all arguments."); - auto emplace_res = attr_id_to_internal_id.try_emplace(id_attr.getInt(), - next_unique_id++); - AddValueUniqueIDMapping(arg, emplace_res.first->getSecond()); - } else { - AddValueUniqueIDMapping(arg, next_unique_id++); - } - } - llvm::StringMap var_handle_name_id_map; - auto forward_input_to_output = [&](const Value& operand, - const Value& result) { - if (!mlir::getElementTypeOrSelf(result.getType()).isa()) - return; - auto& result_ids = resource_value_to_ids_[result]; - auto operand_it = resource_value_to_ids_.find(operand); - assert(operand_it != resource_value_to_ids_.end() && - "A resource-type output does not have the corresponding " - "resource-type input."); - result_ids.insert(operand_it->getSecond().begin(), - operand_it->getSecond().end()); - }; - auto module = func_op.getParentOfType(); - - func_op.walk([&](Operation* op) { - if (auto var_handle = llvm::dyn_cast(op)) { - AddValueUniqueIDMapping( - var_handle.resource(), - GetOrCreateIdForVarHandle(var_handle, &next_unique_id, - &var_handle_name_id_map)); - } else if (llvm::isa(op)) { - for (auto operand_and_result : - llvm::zip(op->getOperands(), op->getResults())) { - forward_input_to_output(std::get<0>(operand_and_result), - std::get<1>(operand_and_result)); - } - } else if (auto replicate = llvm::dyn_cast(op)) { - // The nested block for ReplicateOp is handled separately in side-effect - // analysis. 
Inside that block, we can still treat its block arguments as - // different resources. - for (auto arg : replicate.GetBody().getArguments()) { - if (mlir::getElementTypeOrSelf(arg.getType()).isa()) { - AddValueUniqueIDMapping(arg, next_unique_id++); - } - } - } else if (auto while_op = llvm::dyn_cast(op)) { - auto body = llvm::cast(module.lookupSymbol(while_op.body())); - // If a result is a passthrough of the body input, use the corresponding - // operand's resource IDs. - for (auto result : llvm::enumerate(while_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value().getType()) - .isa()) { - continue; - } - int64_t passthrough_operand = - FindPassthroughArgumentForReturnValue(result.index(), body); - if (passthrough_operand >= 0) { - forward_input_to_output(while_op.getOperand(passthrough_operand), - result.value()); - } else { - AddValueUniqueIDMapping(result.value(), kUnknownResourceId); - } - } - } else if (auto if_op = llvm::dyn_cast(op)) { - auto then_branch = - llvm::cast(module.lookupSymbol(if_op.then_branch())); - auto else_branch = - llvm::cast(module.lookupSymbol(if_op.else_branch())); - // If a result is a passthrough of both branches' inputs, merge the - // resource IDs of corresponding operands for the two inputs. - for (auto result : llvm::enumerate(if_op.getResults())) { - if (!mlir::getElementTypeOrSelf(result.value().getType()) - .isa()) { - continue; - } - int64_t passthrough_then_arg = - FindPassthroughArgumentForReturnValue(result.index(), then_branch); - int64_t passthrough_else_arg = - FindPassthroughArgumentForReturnValue(result.index(), else_branch); - if (passthrough_then_arg >= 0 && passthrough_else_arg >= 0) { - forward_input_to_output(if_op.getOperand(passthrough_then_arg + 1), - result.value()); - forward_input_to_output(if_op.getOperand(passthrough_else_arg + 1), - result.value()); - } else { - AddValueUniqueIDMapping(result.value(), kUnknownResourceId); - } - } - } else { - for (auto result : op->getResults()) { - if (!mlir::getElementTypeOrSelf(result.getType()) - .isa()) - continue; - AddValueUniqueIDMapping(result, kUnknownResourceId); - } - } - }); -} - -bool ResourceAliasAnalysis::IsUnknownResource(const Value resource) const { - auto it = resource_value_to_ids_.find(resource); - assert(it != resource_value_to_ids_.end() && !it->getSecond().empty()); - // The set is sorted so we only need to check the first element since - // kUnknownResourceId < 0. 
- static_assert(kUnknownResourceId < 0, - "kUnknownResourceId should be negative"); - return *it->getSecond().begin() == kUnknownResourceId; -} - -const llvm::SmallSet& ResourceAliasAnalysis::GetResourceUniqueIds( - const Value resource) const { - auto it = resource_value_to_ids_.find(resource); - assert(it != resource_value_to_ids_.end() && "Unseen resource was queried"); - return it->getSecond(); -} - -const llvm::SmallSetVector& -ResourceAliasAnalysis::GetUniqueIdResources(const int64_t id) const { - auto it = id_to_resource_values_.find(id); - assert(it != id_to_resource_values_.end() && "Unseen id was queried"); - return it->getSecond(); -} - -llvm::SmallSetVector ResourceAliasAnalysis::GetResourceAliases( - const Value resource) const { - assert(!IsUnknownResource(resource) && "Unseen resource was queried"); - llvm::SmallSetVector aliases; - for (int64_t id : GetResourceUniqueIds(resource)) { - const llvm::SmallSetVector& resources_aliasing_id = - GetUniqueIdResources(id); - aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end()); - } - return aliases; -} -namespace { +//===----------------------------------------------------------------------===// +// SideEffectAnalysisInfo helper functions. +//===----------------------------------------------------------------------===// // Returns a set that contains only kUnknownResourceId. llvm::SmallDenseSet UnknownResourceSet() { @@ -284,7 +64,7 @@ llvm::SmallDenseSet UnknownResourceSet() { // Returns all resources that could be accessed by op, or UnknownResourceSet() // if we cannot find all of them. llvm::SmallDenseSet FindAccessedResources( - Operation* op, const ResourceAliasAnalysis& alias_analysis) { + Operation* op, const ResourceAliasAnalysis::Info& alias_analysis) { llvm::SmallDenseSet resources; for (auto operand : op->getOperands()) { @@ -311,7 +91,6 @@ llvm::SmallDenseSet FindAccessedResources( // TODO(yuanzx): Define this information in a different place. Currently we use // tensorflow/compiler/tf2xla/resource_operation_table.h. const tensorflow::XlaResourceOpInfo* GetResourceInfoForOp(Operation* op) { - auto op_name = op->getName().getStringRef().str(); if (op->getName().getDialect() != TF::TensorFlowDialect::getDialectNamespace()) { return nullptr; @@ -329,7 +108,7 @@ bool OpIsReadOnly(Operation* op) { // Returns if `op` is a resource declaration. bool OpIsDeclaration(Operation* op, - const ResourceAliasAnalysis& alias_analysis) { + const ResourceAliasAnalysis::Info& alias_analysis) { // TODO(yuanzx): Add other types of resources. return llvm::isa(op) || (llvm::isa(op) && @@ -370,8 +149,13 @@ bool OpIsKnownToHaveNoSideEffect(Operation* op) { } // namespace -void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op, - bool read_only) { +namespace detail { +//===----------------------------------------------------------------------===// +// SideEffectAnalysisInfo +//===----------------------------------------------------------------------===// + +void SideEffectAnalysisInfo::TrackAccess(int64_t resource_id, Operation* op, + bool read_only) { if (resource_id == kUnknownResourceId) { if (read_only) { // New unknown read is not tracked by any known resource access. 
@@ -402,9 +186,9 @@ void SideEffectAnalysis::TrackAccess(int64_t resource_id, Operation* op, } } -void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id, - Operation* op, - bool read_only) { +void SideEffectAnalysisInfo::AddPredecessorsForAccess(int64_t resource_id, + Operation* op, + bool read_only) { auto it = per_resource_access_info_.find(resource_id); if (it == per_resource_access_info_.end()) return; const auto& access_info = it->getSecond(); @@ -420,8 +204,8 @@ void SideEffectAnalysis::AddPredecessorsForAccess(int64_t resource_id, } } -void SideEffectAnalysis::AnalyzeFunction( - FuncOp func_op, const ResourceAliasAnalysis& alias_analysis) { +void SideEffectAnalysisInfo::AnalyzeFunction( + FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) { // AnalyzeRegion() recursively analyzes the function body, and only populates // control_predecessors_. AnalyzeRegion(&func_op.getBody(), alias_analysis); @@ -448,8 +232,8 @@ void SideEffectAnalysis::AnalyzeFunction( } } -void SideEffectAnalysis::AnalyzeRegion( - Region* region, const ResourceAliasAnalysis& alias_analysis) { +void SideEffectAnalysisInfo::AnalyzeRegion( + Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) { // This function populates control_predecessors_ by walking through the // region, and tracking resource accesses in per_resource_access_info_. @@ -476,13 +260,12 @@ void SideEffectAnalysis::AnalyzeRegion( // different nested regions separately. for (auto& block : *region) { for (auto& op : block) { - if (op.getNumRegions() > 0) { - llvm::SmallVector child_analyses; - for (auto& child_region : op.getRegions()) { - child_analyses.emplace_back(); - child_analyses.back().AnalyzeRegion(&child_region, alias_analysis); - } - ConsumeChildAnalyses(std::move(child_analyses)); + for (Region& child : op.getRegions()) { + SideEffectAnalysisInfo child_analysis(&child, alias_analysis); + // Moves the control_predecessors_ fields in child region to current + // region + for (auto& entry : child_analysis.control_predecessors_) + control_predecessors_[entry.first] = std::move(entry.second); } // We do not need explicit control edges for declaration ops. 
@@ -529,16 +312,8 @@ void SideEffectAnalysis::AnalyzeRegion( } } -void SideEffectAnalysis::ConsumeChildAnalyses( - llvm::SmallVector&& children) { - for (auto& child : children) { - for (auto& entry : child.control_predecessors_) { - control_predecessors_[entry.getFirst()] = std::move(entry.getSecond()); - } - } -} - -llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors( +llvm::SmallVector +SideEffectAnalysisInfo::DirectControlPredecessors( Operation* op, llvm::function_ref filter) const { llvm::SmallVector result; auto it = sorted_control_predecessors_.find(op); @@ -550,7 +325,8 @@ llvm::SmallVector SideEffectAnalysis::DirectControlPredecessors( return result; } -llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors( +llvm::SmallVector +SideEffectAnalysisInfo::DirectControlSuccessors( Operation* op, llvm::function_ref filter) const { llvm::SmallVector result; auto it = sorted_control_successors_.find(op); @@ -561,12 +337,19 @@ llvm::SmallVector SideEffectAnalysis::DirectControlSuccessors( } return result; } +} // namespace detail SideEffectAnalysis::SideEffectAnalysis(Operation* op) { - auto func_op = llvm::dyn_cast(op); - if (!func_op) return; - ResourceAliasAnalysis alias_analysis(op); - AnalyzeFunction(func_op, alias_analysis); + auto module = dyn_cast(op); + assert(module); + + // Analyze entire module for alias analysis info. + ResourceAliasAnalysis alias_analysis(module); + + // Analyze all functions. + for (auto func : module.getOps()) + this->info_map_.try_emplace(func, func, + alias_analysis.GetAnalysisForFunc(func)); } } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index a318c6667c6..c92c6e1882c 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_ANALYSIS_SIDE_EFFECT_ANALYSIS_H_ +#include #include #include @@ -23,78 +24,33 @@ limitations under the License. #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" -#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" namespace mlir { namespace TF { +namespace detail { -// An analysis that runs on a function and maps each resource-type value to a -// set of unique int64_t IDs representing the possible resources it could alias. -// -// If there are nested regions, each region is handled separately. This means -// cross-region aliasing cannot be checked by this analysis. -class ResourceAliasAnalysis { +// Side effect analysis info for a single function. +class SideEffectAnalysisInfo { public: - explicit ResourceAliasAnalysis(Operation* op); - ~ResourceAliasAnalysis() = default; - ResourceAliasAnalysis(ResourceAliasAnalysis&&) = default; + SideEffectAnalysisInfo() = default; - // Returns if the analysis fails to resolve a resource-type value. - bool IsUnknownResource(const Value resource) const; - - // Returns the set unique IDs which `resource` could alias. Requires that - // IsUnknownResource(resource) == true. 
- const llvm::SmallSet& GetResourceUniqueIds( - const Value resource) const; - - // Returns the set of values that are potentially aliases of `value`. Requires - // that IsUnknownResource(resource) == true. - llvm::SmallSetVector GetResourceAliases(const Value resource) const; - - private: - ResourceAliasAnalysis() = default; - - // Runs the analysis on `func_op` and populates two way resource values to - // unique ID mapping. - void AnalyzeFunction(FuncOp func_op); - - // Maps resource value to unique ID and vice-versa. - void AddValueUniqueIDMapping(Value value, int64_t id) { - resource_value_to_ids_[value].insert(id); - id_to_resource_values_[id].insert(value); + // Constructs analysis info by analyzing the given function. + SideEffectAnalysisInfo( + FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) { + AnalyzeFunction(func_op, alias_analysis); } - // Returns the set unique Values which map to `id`. - const llvm::SmallSetVector& GetUniqueIdResources(int64_t id) const; + // Constructs analysis info by analyzing the given region. + SideEffectAnalysisInfo( + Region* region, const TF::ResourceAliasAnalysis::Info& alias_analysis) { + AnalyzeRegion(region, alias_analysis); + } - // Maps each resource-type value to a set of unique IDs that it could alias. - llvm::SmallDenseMap, 8> - resource_value_to_ids_; - - // Maps each unique ID to a set of resource-type values that could alias to - // it. This is inverse of `resource_value_to_ids_` map. - llvm::SmallDenseMap, 8> - id_to_resource_values_; -}; - -// An analysis that runs on a function and infers the control predecessors and -// successors for each op, based on side-effects on known and unknown resources. -// Side-effecting ops on unknown resources are conservatively treated as -// interfering with all known resource op accesses. It distinguishes accesses -// based on whether they are read-only, and read-only ops do not interfere with -// each other. -// -// If there are nested regions, each region is handled separately, and control -// dependencies are only tracked for ops under the same parent op. -class SideEffectAnalysis { - public: - explicit SideEffectAnalysis() = default; - explicit SideEffectAnalysis(Operation* op); - SideEffectAnalysis(SideEffectAnalysis&& other) = default; - ~SideEffectAnalysis() = default; + SideEffectAnalysisInfo(SideEffectAnalysisInfo&&) = default; // Returns a vector of ops that are direct control predecessors of `op`, // sorted in program order. If `filter` is provided, only predecessors that @@ -103,9 +59,9 @@ class SideEffectAnalysis { Operation* op, llvm::function_ref filter = nullptr) const; - // Returns a vector of ops that are direct control successors of `op`, sorted - // in program order. If `filter` is provided, only successors that pass the - // filter (returning true) will be included. + // Returns a vector of ops that are direct control successors of `op`, + // sorted in program order. If `filter` is provided, only successors that + // pass the filter (returning true) will be included. llvm::SmallVector DirectControlSuccessors( Operation* op, llvm::function_ref filter = nullptr) const; @@ -114,16 +70,11 @@ class SideEffectAnalysis { // Runs the analysis on `func_op` and populates sorted_control_predecessors_ // and sorted_control_successors_. void AnalyzeFunction(FuncOp func_op, - const ResourceAliasAnalysis& alias_analysis); + const TF::ResourceAliasAnalysis::Info& alias_analysis); // Runs the analysis on `region` and populates control_predecessors_. 
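A hedged sketch of how a transform might consume the refactored, module-level
side effect analysis (function name invented, `using namespace mlir` assumed):

// Hypothetical consumer: query per-function control dependencies.
void ExampleSideEffectQuery(ModuleOp module) {
  TF::SideEffectAnalysis analysis(module);
  for (FuncOp func : module.getOps<FuncOp>()) {
    const auto& info = analysis.GetAnalysisForFunc(func);
    func.walk([&](Operation* op) {
      // Ops that must stay ordered before `op` to preserve side effects.
      for (Operation* pred : info.DirectControlPredecessors(op)) (void)pred;
    });
  }
}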
void AnalyzeRegion(Region* region, - const ResourceAliasAnalysis& alias_analysis); - - // Moves the control_predecessors_ fields in `children` analyses to this - // current analysis. - void ConsumeChildAnalyses( - llvm::SmallVector&& children); + const TF::ResourceAliasAnalysis::Info& alias_analysis); // Updates control_predecessors_ for `op` that is being visited, on the given // `resource_id`. @@ -159,10 +110,29 @@ class SideEffectAnalysis { // write for a the current write being analyzed. bool tracked_last_unknown_write_for_write = false; }; + llvm::SmallDenseMap per_resource_access_info_; }; +} // namespace detail + +// An analysis that runs on a function and infers the control predecessors and +// successors for each op, based on side-effects on known and unknown resources. +// Side-effecting ops on unknown resources are conservatively treated as +// interfering with all known resource op accesses. It distinguishes accesses +// based on whether they are read-only, and read-only ops do not interfere with +// each other. +// +// If there are nested regions, each region is handled separately, and control +// dependencies are only tracked for ops under the same parent op. +class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis< + detail::SideEffectAnalysisInfo> { + public: + // Constructs analysis by analyzing the given module operation. + explicit SideEffectAnalysis(Operation* op); +}; + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index ffd9c149d2d..66447995709 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -102,8 +102,6 @@ class MlirTensor : public TracingTensorHandle { return type; } - void Release() override { delete this; } - Value getValue() { return value_; } // For LLVM style RTTI. @@ -564,7 +562,7 @@ Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { } PassManager pm(func_.getContext()); pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); - pm.addNestedPass(CreateBreakUpIslandsPass()); + pm.addPass(CreateBreakUpIslandsPass()); // In case of failure, the `diag_handler` converts MLIR errors emitted to // the MLIRContext into a tensorflow::Status. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 77008b55672..5345000b4bd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -101,7 +101,8 @@ bool BlockWrapsSingleOp(Block* block) { } // end anonymous namespace TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) - : Dialect(/*name=*/"tf_device", context) { + : Dialect(/*name=*/"tf_device", context, + TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc" @@ -118,31 +119,6 @@ TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext* context) // operation results are perfectly forwarded to the launch return. 
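
For reference, a minimal sketch (not part of this patch) of how a transformation pass might consume the refactored side-effect analysis above: the module-level `SideEffectAnalysis` constructor analyzes every function once, and per-function queries go through the `detail::SideEffectAnalysisInfo` objects it aggregates. The `GetAnalysisForFunc` accessor on the aggregate base class is assumed here, mirroring the `ResourceAliasAnalysis::GetAnalysisForFunc` call used in the constructor.

```
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"

// Sketch only: walks each function and asks the per-function analysis info
// for the control predecessors of every op, in program order.
void QuerySideEffects(mlir::ModuleOp module) {
  mlir::TF::SideEffectAnalysis analysis(module);  // analyzes all funcs once
  for (auto func : module.getOps<mlir::FuncOp>()) {
    // Assumed accessor on the PerFunctionAggregateAnalysis base class.
    const auto& info = analysis.GetAnalysisForFunc(func);
    func.walk([&](mlir::Operation* op) {
      for (mlir::Operation* pred : info.DirectControlPredecessors(op)) {
        (void)pred;  // e.g., add a control edge pred -> op here
      }
    });
  }
}
```
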
bool LaunchOp::WrapsSingleOp() { return BlockWrapsSingleOp(&GetBody()); } -//===----------------------------------------------------------------------===// -// tf_device.return -//===----------------------------------------------------------------------===// - -namespace { -ParseResult ParseReturnOp(OpAsmParser* parser, OperationState* state) { - llvm::SmallVector op_info; - llvm::SmallVector types; - llvm::SMLoc loc = parser->getCurrentLocation(); - return failure(parser->parseOperandList(op_info) || - (!op_info.empty() && parser->parseColonTypeList(types)) || - parser->resolveOperands(op_info, types, loc, state->operands)); -} - -void Print(ReturnOp op, OpAsmPrinter* p) { - *p << op.getOperationName(); - if (op.getNumOperands() > 0) { - *p << ' '; - p->printOperands(op.getOperands()); - *p << " : "; - interleaveComma(op.getOperandTypes(), *p); - } -} -} // anonymous namespace - //===----------------------------------------------------------------------===// // tf_device.parallel_execute //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h index 4c20d1ccc4f..688c8ca5715 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -35,6 +36,7 @@ namespace tf_device { // XlaRun. class TensorFlowDeviceDialect : public Dialect { public: + static StringRef getDialectNamespace() { return "tf_device"; } // Constructing TensorFlowDevice dialect under an non-null MLIRContext. explicit TensorFlowDeviceDialect(MLIRContext* context); }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 3a92e3237dc..d94a37d9b02 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -104,8 +104,7 @@ The `tf_device.return` operation terminates and returns values from a }]> ]; - let parser = [{ return Parse$cppClass(&parser, &result); }]; - let printer = [{ return Print(*this, &p); }]; + let assemblyFormat = "attr-dict ($results^ `:` type($results))?"; } def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> { @@ -354,7 +353,10 @@ This op is used for outlining a cluster. ); let extraClassDeclaration = [{ - StringRef getFunc() { return func(); } + // returns the function that this operation will launch. 
+ FuncOp getFunc() { + return SymbolTable::lookupNearestSymbolFrom(*this, func()); + } }]; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 1b1d5ba6f3b..9c2968fab37 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -92,7 +92,8 @@ struct TensorFlowExecutorOpFolderDialectInterface } // namespace TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) - : Dialect(/*name=*/"tf_executor", context) { + : Dialect(/*name=*/"tf_executor", context, + TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc.inc" @@ -190,14 +191,15 @@ LogicalResult Verify(GraphOp graph) { for (int i : llvm::seq(0, fetch.getNumOperands())) { Value operand = fetch.getOperand(i); // Break out of the loop at the first control operand encountered. + const int64_t num_results = graph.getNumResults(); if (operand.getType().isa()) { - if (i != graph.getNumResults()) + if (i != num_results) return fetch.emitOpError() << "operand #" << i << " is a control type, can't be bound to a graph result"; break; } - if (i >= graph.getNumResults()) + if (i >= num_results) return fetch.emitOpError() << "operand #" << i << " does not have a graph results to bind"; if (graph.getResult(i).getType() != operand.getType()) @@ -301,8 +303,8 @@ bool IslandOp::WrapsSingleOp() { namespace { LogicalResult Verify(IslandOp island) { - if (island.GetBody().empty()) - return island.emitOpError() << "expects a non-empty body"; + if (!island.GetBody().args_empty()) + return island.emitOpError() << "expects body without any arguments"; Operation &yield = island.GetBody().back(); if (!isa(yield)) @@ -311,7 +313,8 @@ LogicalResult Verify(IslandOp island) { // Ensure that the yield terminator operands matches the island results type. int result_count = island.getNumResults() - 1; // -1 for the control token - if (yield.getNumOperands() != result_count) + const int num_operands = yield.getNumOperands(); + if (num_operands != result_count) return yield.emitOpError() << "has " << yield.getNumOperands() << " operand, but island returns " << result_count; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index 3bb30f16c3d..61358172d6d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -35,6 +35,7 @@ namespace tf_executor { class TensorFlowExecutorDialect : public Dialect { public: + static StringRef getDialectNamespace() { return "tf_executor"; } explicit TensorFlowExecutorDialect(MLIRContext *context); // Parses a type registered to this dialect. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index a0e73f116cf..bf8d7015b46 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -87,7 +87,7 @@ tf.math.acosh(x) ==> [nan nan 0. 
0.62236255 5.9914584 9.903487 inf] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>, +def TF_AddOp : TF_Op<"Add", [NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -136,7 +136,7 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, SameOperandsAndResultElementType]>, +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -648,7 +648,7 @@ tf.math.atan(y) # [1.047, 0.785] = x TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_Atan2Op : TF_Op<"Atan2", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Computes arctangent of `y/x` element-wise, respecting signs of the arguments. @@ -725,6 +725,30 @@ window in `value`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_AvgPool3DOp : TF_Op<"AvgPool3D", [NoSideEffect]> { + let summary = "Performs 3D average pooling on the input."; + + let description = [{ +Each entry in `output` is the mean of the corresponding size `ksize` window in +`value`. + }]; + + let arguments = (ins + TF_FpTensor:$input, + + Confined]>:$ksize, + Confined]>:$strides, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding, + DefaultValuedAttr, "NDHWC">:$data_format + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AvgPool3DGradOp : TF_Op<"AvgPool3DGrad", [NoSideEffect]> { let summary = "Computes gradients of average pooling function."; @@ -765,7 +789,7 @@ def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; } -def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_BatchMatMulOp : TF_Op<"BatchMatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Multiplies slices of two tensors in batches."; let description = [{ @@ -806,7 +830,7 @@ It is computed as: let hasCanonicalizer = 1; } -def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_BatchMatMulV2Op : TF_Op<"BatchMatMulV2", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Multiplies slices of two tensors in batches."; let description = [{ @@ -1326,48 +1350,6 @@ then the output will be TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_CaseOp : TF_Op<"Case", []> { - let summary = [{ -An n-way switch statement which calls a single branch function. - }]; - - let description = [{ -An n-way switch statement, implementing the following: - ``` - switch (branch_index) { - case 0: - output = branches[0](input); - break; - case 1: - output = branches[1](input); - break; - ... 
- case [[nbranches-1]]: - default: - output = branches[nbranches-1](input); - break; - } - ``` - }]; - - let arguments = (ins - I32Tensor:$branch_index, - Variadic:$input, - - Confined]>:$branches, - DefaultValuedAttr:$output_shapes - ); - - let results = (outs - Variadic:$output - ); - - TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; - TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; - - let hasCanonicalizer = 1; -} - def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; @@ -1422,7 +1404,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Clips tensor values to a specified min and max."; let description = [{ @@ -1534,6 +1516,29 @@ Mutually reduces multiple tensors of identical type and shape. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CollectiveReduceV2Op : TF_Op<"CollectiveReduceV2", []> { + let summary = [{ +Mutually reduces multiple tensors of identical type and shape. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, I32, I64]>:$input, + I32Tensor:$group_size, + I32Tensor:$group_key, + I32Tensor:$instance_key, + + TF_AnyStrAttrOf<["Min", "Max", "Mul", "Add"]>:$merge_op, + TF_AnyStrAttrOf<["Id", "Div"]>:$final_op, + DefaultValuedAttr:$communication_hint + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$data + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> { let summary = "Converts two real numbers to a complex number."; @@ -1664,6 +1669,8 @@ def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_ConjOp : TF_Op<"Conj", [NoSideEffect, SameOperandsAndResultType]> { @@ -2469,7 +2476,7 @@ Computes Psi, the derivative of Lgamma (the log of the absolute value of TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise."; @@ -2494,7 +2501,7 @@ def TF_DivOp : TF_Op<"Div", [NoSideEffect, ResultsBroadcastableShape, SameOperan let hasFolder = 1; } -def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_DivNoNanOp : TF_Op<"DivNoNan", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if the denominator is zero."; @@ -3374,7 +3381,7 @@ def TF_FloorDivOp : TF_Op<"FloorDiv", [NoSideEffect, ResultsBroadcastableShape]> TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_FloorModOp : TF_Op<"FloorMod", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns 
element-wise remainder of division. When `x < 0` xor `y < 0` is @@ -3540,51 +3547,6 @@ The size of 1D Tensors matches the dimension C of the 4D Tensors. }]; } -def TF_FusedBatchNormV3Op : TF_Op<"FusedBatchNormV3", [NoSideEffect, TF_FoldOperandsTransposeInterface, TF_LayoutSensitiveInterface]> { - let summary = "Batch normalization."; - - let description = [{ -Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". -The size of 1D Tensors matches the dimension C of the 4D Tensors. - }]; - - let arguments = (ins - TensorOf<[BF16, F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - let results = (outs - TensorOf<[BF16, F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {0}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - - // TF_LayoutSensitiveInterface: - StringRef GetOptimalLayout(const RuntimeDevices& devices); - LogicalResult UpdateDataFormat(StringRef data_format); - }]; -} - def TF_GatherOp : TF_Op<"Gather", [NoSideEffect]> { let summary = "Gather slices from `params` according to `indices`."; @@ -4111,7 +4073,7 @@ def ApplyG(op, dy, _): TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; } -def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_IgammaOp : TF_Op<"Igamma", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Compute the lower regularized incomplete Gamma function `P(a, x)`. @@ -4145,7 +4107,7 @@ Gamma function. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Computes the gradient of `igamma(a, x)` wrt `a`."; @@ -4161,7 +4123,7 @@ def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableS TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_IgammacOp : TF_Op<"Igammac", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Compute the upper regularized incomplete Gamma function `Q(a, x)`. @@ -4252,6 +4214,29 @@ Where to extract the key and value from a line is specified by `key_index` and let results = (outs); } +def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { + let summary = "Updates specified rows 'i' with values 'v'."; + + let description = [{ +Computes `x[i, :] = v; return x`. 
+ +Originally this function is mutative however for compilation we make this +operation create / operate on a copy of `x`. + }]; + + let arguments = (ins + TF_Tensor:$x, + I32Tensor:$i, + TF_Tensor:$v + ); + + let results = (outs + TF_Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_InvOp : TF_Op<"Inv", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the reciprocal of x element-wise."; @@ -4772,7 +4757,7 @@ tf.math.log(x) ==> [-inf, -0.6931472, 0. , 1.609438] let hasCanonicalizer = 1; } -def TF_Log1pOp : TF_Op<"Log1p", [NoSideEffect, SameOperandsAndResultType]> { +def TF_Log1pOp : TF_Op<"Log1p", [NoSideEffect, SameOperandsAndResultType, TF_CwiseUnary]> { let summary = "Computes natural logarithm of (1 + x) element-wise."; let description = [{ @@ -4928,7 +4913,7 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); } -def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". }]; @@ -5066,6 +5051,126 @@ which has shape (2, 4, 4) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixDiagPartV3Op : TF_Op<"MatrixDiagPartV3", [NoSideEffect]> { + let summary = "Returns the batched diagonal part of a batched tensor."; + + let description = [{ +Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched +`input`. + +Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`. +Let `max_diag_len` be the maximum length among all diagonals to be extracted, +`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))` +Let `num_diags` be the number of diagonals to extract, +`num_diags = k[1] - k[0] + 1`. + +If `num_diags == 1`, the output tensor is of rank `r - 1` with shape +`[I, J, ..., L, max_diag_len]` and values: + +``` +diagonal[i, j, ..., l, n] + = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + padding_value ; otherwise. +``` +where `y = max(-k[1], 0)`, `x = max(k[1], 0)`. + +Otherwise, the output tensor has rank `r` with dimensions +`[I, J, ..., L, num_diags, max_diag_len]` with values: + +``` +diagonal[i, j, ..., l, m, n] + = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N, + padding_value ; otherwise. +``` +where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`. + +`offset` is zero except when the alignment of the diagonal is to the right. +``` +offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT} + and `d >= 0`) or + (`align` in {LEFT_RIGHT, RIGHT_RIGHT} + and `d <= 0`) + 0 ; otherwise +``` +where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`. + +The input must be at least a matrix. + +For example: + +``` +input = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4) + [5, 6, 7, 8], + [9, 8, 7, 6]], + [[5, 4, 3, 2], + [1, 2, 3, 4], + [5, 6, 7, 8]]]) + +# A main diagonal from each batch. +tf.matrix_diag_part(input) ==> [[1, 6, 7], # Output shape: (2, 3) + [5, 2, 7]] + +# A superdiagonal from each batch. +tf.matrix_diag_part(input, k = 1) + ==> [[2, 7, 6], # Output shape: (2, 3) + [4, 3, 8]] + +# A band from each batch. +tf.matrix_diag_part(input, k = (-1, 2)) + ==> [[[0, 3, 8], # Output shape: (2, 4, 3) + [2, 7, 6], + [1, 6, 7], + [5, 8, 0]], + [[0, 3, 4], + [4, 3, 8], + [5, 2, 7], + [1, 6, 0]]] + +# LEFT_RIGHT alignment. 
+tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT") + ==> [[[3, 8, 0], # Output shape: (2, 4, 3) + [2, 7, 6], + [1, 6, 7], + [0, 5, 8]], + [[3, 4, 0], + [4, 3, 8], + [5, 2, 7], + [0, 1, 6]]] + +# max_diag_len can be shorter than the main diagonal. +tf.matrix_diag_part(input, k = (-2, -1)) + ==> [[[5, 8], + [9, 0]], + [[1, 6], + [5, 0]]] + +# padding_value = 9 +tf.matrix_diag_part(input, k = (1, 3), padding_value = 9) + ==> [[[9, 9, 4], # Output shape: (2, 3, 3) + [9, 3, 8], + [2, 7, 6]], + [[9, 9, 2], + [9, 3, 4], + [4, 3, 8]]] + +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + I32Tensor:$k, + TF_Tensor:$padding_value, + + DefaultValuedAttr, "RIGHT_LEFT">:$align + ); + + let results = (outs + TF_Tensor:$diagonal + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixDiagV2Op : TF_Op<"MatrixDiagV2", [NoSideEffect]> { let summary = [{ Returns a batched diagonal tensor with given batched diagonal values. @@ -5692,7 +5797,7 @@ def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { }]; } -def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns the max of x and y (i.e. x > y ? x : y) element-wise."; @@ -5766,7 +5871,7 @@ retained with length 1. TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; } -def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_MinimumOp : TF_Op<"Minimum", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns the min of x and y (i.e. x < y ? x : y) element-wise."; @@ -5899,7 +6004,7 @@ graph_def = foo.get_concrete_function(tf.TensorSpec([10], tf.float32), tf.Tensor TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; } -def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_ModOp : TF_Op<"Mod", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns element-wise remainder of division. This emulates C semantics in that @@ -5925,7 +6030,7 @@ the result here is consistent with a truncating divide. E.g. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -5971,7 +6076,7 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MultinomialOp : TF_Op<"Multinomial", []> { +def TF_MultinomialOp : TF_Op<"Multinomial", [TF_CannotDuplicate]> { let summary = "Draws samples from a multinomial distribution."; let arguments = (ins @@ -6332,6 +6437,8 @@ This is the opposite of `unpack`. 
let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_PadOp : TF_Op<"Pad", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { @@ -6426,7 +6533,36 @@ pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0] TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>; } -def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_ParameterizedTruncatedNormalOp : TF_Op<"ParameterizedTruncatedNormal", [TF_CannotDuplicate]> { + let summary = [{ +Outputs random values from a normal distribution. The parameters may each be a + }]; + + let description = [{ +scalar which applies to the entire output, or a vector of length shape[0] which +stores the parameters for each batch. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_FpTensor:$means, + TF_FpTensor:$stdevs, + TF_FpTensor:$minvals, + TF_FpTensor:$maxvals, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; +} + +def TF_PowOp : TF_Op<"Pow", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Computes the power of one value to another."; @@ -6809,6 +6945,33 @@ array([0.6666667, 1. , 1. ], dtype=float32) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RandomGammaOp : TF_Op<"RandomGamma", [TF_CannotDuplicate]> { + let summary = [{ +Outputs random values from the Gamma distribution(s) described by alpha. + }]; + + let description = [{ +This op uses the algorithm by Marsaglia et al. to acquire samples via +transformation-rejection from pairs of uniform and normal random variables. +See http://dl.acm.org/citation.cfm?id=358414 + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TensorOf<[F16, F32, F64]>:$alpha, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TensorOf<[F16, F32, F64]>:$output + ); + + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = [{ @@ -6827,7 +6990,60 @@ Computes the derivative of a Gamma random sample w.r.t. `alpha`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_RandomShuffleOp : TF_Op<"RandomShuffle", [SameOperandsAndResultType]> { +def TF_RandomPoissonOp : TF_Op<"RandomPoisson", [TF_CannotDuplicate]> { + let summary = "Use RandomPoissonV2 instead."; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TensorOf<[F16, F32, F64]>:$rate, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TensorOf<[F16, F32, F64]>:$output + ); + + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<1>; +} + +def TF_RandomPoissonV2Op : TF_Op<"RandomPoissonV2", [TF_CannotDuplicate]> { + let summary = [{ +Outputs random values from the Poisson distribution(s) described by rate. + }]; + + let description = [{ +This op uses two algorithms, depending on rate. If rate >= 10, then +the algorithm by Hormann is used to acquire samples via +transformation-rejection. +See http://www.sciencedirect.com/science/article/pii/0167668793909974. 
+ +Otherwise, Knuth's algorithm is used to acquire samples via multiplying uniform +random variables. +See Donald E. Knuth (1969). Seminumerical Algorithms. The Art of Computer +Programming, Volume 2. Addison Wesley + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TensorOf<[F16, F32, F64, I32, I64]>:$rate, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TensorOf<[F16, F32, F64, I32, I64]>:$output + ); + + TF_DerivedOperandTypeAttr R = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr S = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_RandomShuffleOp : TF_Op<"RandomShuffle", [SameOperandsAndResultType, TF_CannotDuplicate]> { let summary = "Randomly shuffles a tensor along its first dimension."; let description = [{ @@ -6856,7 +7072,7 @@ The tensor is shuffled along dimension 0, such that each `value[j]` is mapped TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_RandomStandardNormalOp : TF_Op<"RandomStandardNormal", []> { +def TF_RandomStandardNormalOp : TF_Op<"RandomStandardNormal", [TF_CannotDuplicate]> { let summary = "Outputs random values from a normal distribution."; let description = [{ @@ -6878,7 +7094,7 @@ The generated values will have mean 0 and standard deviation 1. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } -def TF_RandomUniformOp : TF_Op<"RandomUniform", []> { +def TF_RandomUniformOp : TF_Op<"RandomUniform", [TF_CannotDuplicate]> { let summary = "Outputs random values from a uniform distribution."; let description = [{ @@ -6905,7 +7121,37 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded. }]; } -def TF_RangeOp : TF_Op<"Range", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_RandomUniformIntOp : TF_Op<"RandomUniformInt", [TF_CannotDuplicate]> { + let summary = "Outputs random integers from a uniform distribution."; + + let description = [{ +The generated values are uniform integers in the range `[minval, maxval)`. +The lower bound `minval` is included in the range, while the upper bound +`maxval` is excluded. + +The random integers are slightly biased unless `maxval - minval` is an exact +power of two. The bias is small for values of `maxval - minval` significantly +smaller than the range of the output (either `2^32` or `2^64`). + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$minval, + TF_I32OrI64Tensor:$maxval, + + DefaultValuedAttr:$seed, + DefaultValuedAttr:$seed2 + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tout = TF_DerivedOperandTypeAttr<1>; +} + +def TF_RangeOp : TF_Op<"Range", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Creates a sequence of numbers."; let description = [{ @@ -6940,6 +7186,28 @@ tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15] ]; } +def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> { + let summary = [{ +Creates a dataset with a range of values. Corresponds to python's xrange. 
+ }]; + + let description = [{ + }]; + + let arguments = (ins + I64Tensor:$start, + I64Tensor:$stop, + I64Tensor:$step, + + Confined]>:$output_types, + Confined]>:$output_shapes + ); + + let results = (outs + TF_VariantTensor:$handle + ); +} + def TF_RankOp : TF_Op<"Rank", [NoSideEffect]> { let summary = "Returns the rank of a tensor."; @@ -7030,7 +7298,7 @@ tf.real(input) ==> [-2.25, 3.25] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } -def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape]>, +def TF_RealDivOp : TF_Op<"RealDiv", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary]>, WithBroadcastableBinOpBuilder { let summary = "Returns x / y element-wise for real types."; @@ -7272,6 +7540,7 @@ reshape(t, []) ==> 7 }]; let hasCanonicalizer = 1; + let hasFolder = 1; } def TF_ResizeBilinearOp : TF_Op<"ResizeBilinear", [NoSideEffect]> { @@ -9309,6 +9578,33 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom values from a truncated normal distribution. + }]; + + let description = [{ +The generated values follow a normal distribution with mean 0 and standard +deviation 1, except that values whose magnitude is more than 2 standard +deviations from the mean are dropped and re-picked. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "Stops gradient computation."; @@ -9551,7 +9847,7 @@ Examples: TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } -def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; @@ -9731,6 +10027,22 @@ For internal use only. ); } +def TF_TPUOrdinalSelectorOp : TF_Op<"TPUOrdinalSelector", []> { + let summary = "A TPU core selector Op."; + + let description = [{ +This Op produces a set of TPU cores (for warm-up) or a single TPU core +(for regular inference) to execute the TPU program on. The output is +consumed by TPUPartitionedCall. + }]; + + let arguments = (ins); + + let results = (outs + I32Tensor:$device_ordinals + ); +} + def TF_TPUReplicatedInputOp : TF_Op<"TPUReplicatedInput", [NoSideEffect]> { let summary = "Connects N inputs to an N-way replicated TPU computation."; @@ -10708,7 +11020,7 @@ Python Semantics. let hasCanonicalizer = 1; } -def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_TruncateModOp : TF_Op<"TruncateMod", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = [{ Returns element-wise remainder of division. This emulates C semantics in that @@ -10734,7 +11046,7 @@ y + truncate_mod(x, y) = x`. 
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_TruncatedNormalOp : TF_Op<"TruncatedNormal", []> { +def TF_TruncatedNormalOp : TF_Op<"TruncatedNormal", [TF_CannotDuplicate]> { let summary = "Outputs random values from a truncated normal distribution."; let description = [{ @@ -11181,7 +11493,7 @@ where(input) ==> [[0, 0, 0], TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; @@ -11547,7 +11859,7 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, SameOperandsAndResultElementType]> { +def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; let arguments = (ins @@ -11562,7 +11874,7 @@ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect, SameOperandsAndResultElementT TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, SameOperandsAndResultElementType]>, +def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; @@ -11640,7 +11952,7 @@ create these operators. TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>; } -def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> { +def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect, SameOperandsAndResultElementType]> { let summary = [{ Performs a MatMul followed by a specified series of operations. }]; @@ -11666,9 +11978,9 @@ expected to create these operators. }]; let arguments = (ins - F32Tensor:$a, - F32Tensor:$b, - Variadic:$args, + TensorOf<[BF16, F32]>:$a, + TensorOf<[BF16, F32]>:$b, + Variadic>:$args, DefaultValuedAttr:$transpose_a, DefaultValuedAttr:$transpose_b, @@ -11677,31 +11989,13 @@ expected to create these operators. ); let results = (outs - F32Tensor:$product + TensorOf<[BF16, F32]>:$product ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>; } -def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { - let summary = "A host-side computation called from a TPU device."; - - let arguments = (ins - Variadic:$inputs, - - StrAttr:$key, - DefaultValuedAttr:$tpu_core - ); - - let results = (outs - Variadic:$outputs - ); - - TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; - TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; -} - def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", []> { let summary = "An op that receives embeddng activations on the TPU."; @@ -11761,6 +12055,45 @@ used to look up the program in the compilation cache. 
TF_DerivedOperandSizeAttr NumDynamicShapes = TF_DerivedOperandSizeAttr<0>; } +def TF__TPUCompileMlirPlaceholderProgramKeyOp : TF_Op<"_TPUCompileMlirPlaceholderProgramKey", []> { + let summary = [{ +Placeholder program key (compilation cache key) of a _TPUCompileMlir `program`. + }]; + + let description = [{ +This op can be used when certain rewrite passes materialize ops that require a +program key but the _TPUCompileMlir op has not been added yet. Subsequent +rewrite passes must replace this op with a _TPUCompileMlir op `program` output. + }]; + + let arguments = (ins); + + let results = (outs + TF_StrTensor:$program + ); +} + +def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", []> { + let summary = [{ +A pseudo-op to represent host-side computation in an XLA program. + }]; + + let arguments = (ins + Variadic:$inputs, + + StrAttr:$send_key, + StrAttr:$recv_key, + DefaultValuedAttr:$tpu_core + ); + + let results = (outs + Variadic:$outputs + ); + + TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Toutputs = TF_DerivedResultTypeListAttr<0>; +} + def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { let summary = [{ A placeholder op to receive values from a running XLA computation. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 17424b54fc2..1755c975c23 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -59,10 +59,27 @@ TODO: Make invariants more structured so that we can reference them in ops. def TF_OperandsSameAsResultsTypeOrRef : NativeOpTrait< "TF::OperandsSameAsResultsTypeOrRef">; +// Op has the same operand and result element types (or type itself, if scalar) +// after resolving reference types (i.e., after converting reference types to +// their corresponding TensorFlow or standard types). +def TF_SameOperandsAndResultElementTypeResolveRef : NativeOpTrait< + "TF::SameOperandsAndResultElementTypeResolveRef">; + // Layout agnostic operations do not depend on the operands data layout (data // format), as an example all element wise operations are layout agnostic. def TF_LayoutAgnostic : NativeOpTrait<"TF::LayoutAgnostic">; +// Trait to indicate operations that cannot be duplicated as they might carry +// certain state around within their implementations. +def TF_CannotDuplicate : NativeOpTrait<"TF::CannotDuplicate">; + +// Coefficient wise binary operation with implicit broadcasting support, for +// example tf.Sub operation. +def TF_CwiseBinary : NativeOpTrait<"TF::CwiseBinary">; + +// Coefficient wise unary operation, for example tf.Sqrt operation. +def TF_CwiseUnary : NativeOpTrait<"TF::CwiseUnary">; + // Variant of broadcastable trait that considers TF's subtype behavior. class TF_OpIsBroadcastableToRes : And<[ TCOpResIsShapedTypePred, @@ -332,7 +349,7 @@ class TF_DerivedOperandTypeListAttr : DerivedAttr< // This returns a list of shapes so it is used for variadic operands that // can have different shapes. 
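
As an aside, a small sketch (not in this patch) of how the `TF_CannotDuplicate` trait declared above is typically consulted from C++. The `OpTrait::TF::CannotDuplicate` spelling follows the usual `NativeOpTrait` name expansion, and the `tf_traits.h` header location is an assumption.

```
#include "mlir/IR/Operation.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h"  // assumed home of the trait

// Returns true if `op` may be cloned by an optimization; stateful ops such as
// the random-number kernels tagged TF_CannotDuplicate above must be preserved.
bool IsDuplicable(mlir::Operation* op) {
  return !op->hasTrait<mlir::OpTrait::TF::CannotDuplicate>();
}
```
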
class TF_DerivedOperandShapeListAttr : DerivedAttr< - "mlir::TF::OperandShapeRange", + "::mlir::TF::OperandShapeRange", "auto values = getODSOperands(" # idx # ");\n" "return {mlir::TF::OperandShapeIterator(values.begin()), " "mlir::TF::OperandShapeIterator(values.end())};", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc new file mode 100644 index 00000000000..ffcc9f7dd4f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" + +namespace mlir { +namespace TF { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index de6ce2d313a..dbad613d909 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -69,4363 +69,66 @@ limitations under the License. namespace mlir { namespace TF { -// Propagates underscore and device attributes from src to dst. -// TODO(b/158769932): This should be a general feature instead post some policy -// discussion. -static void PropagateAttributes(Operation *src, Operation *dst) { - auto device = mlir::Identifier::get("device", src->getContext()); - for (auto named_attr : src->getAttrs()) { - if (*named_attr.first.begin() == '_' || named_attr.first == device) - dst->setAttr(named_attr.first, named_attr.second); - } -} - -//===----------------------------------------------------------------------===// -// TF op helper functions -//===----------------------------------------------------------------------===// - -// Returns the RankedTensorType for the given operand. TensorFlow constant ops -// may have non-static shape because the shape is not propagated during constant -// folding. If the defining op for the given operand is a constant op, this -// routine uses the constant op's attribute to get the actual shape. -static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { - DenseElementsAttr attr; - if (matchPattern(operand, m_Constant(&attr))) { - return attr.getType().dyn_cast(); - } - return operand.getType().dyn_cast(); -} - -// Returns true if the given `value` is of ranked float tensor type with the -// given `rank`. -static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { - return type && type.getRank() == rank && - type.getElementType().isa(); -} - -// Returns true if the given `value` has the specified rank or has unranked -// type. 
-static inline bool IsOfRankOrUnranked(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() == rank; -} - -// Returns true if the given `value` has at least the specified rank or has -// unranked type. -static inline bool HasRankAtLeast(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() >= rank; -} - -// Returns true if the given `value` has at most the specified rank or has -// unranked type. -static inline bool HasRankAtMost(Value value, int64_t rank) { - RankedTensorType type = GetRankedTensorTypeForOperand(value); - return !type || type.getRank() <= rank; -} - -static bool IsUnknownDimOrRank(int64_t dim_or_rank) { - return dim_or_rank == -1; -} - -// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If -// `incompatible_shape_error` is true, reports error if `x` and `y` has -// incompatible shapes. Otherwise, returns a tensor type with unknown rank. -static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = - OpTrait::util::getBroadcastedType(x.getType(), y.getType()); - if (!result_type) { - if (incompatible_shape_error.getValue()) { - mlir::emitError(loc, "non-broadcastable operands"); - } else { - return UnrankedTensorType::get(builder->getI1Type()); - } - } - - auto ranked_type = result_type.dyn_cast(); - if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); - - return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); -} - -// Returns dimension index for the given TensorFlow axis that supports negative -// indexing. -static int64_t GetDimForAxis(int64_t axis, int64_t rank) { - return axis >= 0 ? axis : axis + rank; -} - -// Infers output type for reduction ops such as SumOp, MaxOp etc. -// TODO(b/e667204a): Move this logic to shape inference once it supports custom -// inference functions. -static Type InferReductionOpType(Value input, Value reduction_indices, - BoolAttr keep_dims, Builder *builder) { - Type input_ty = input.getType(); - Type element_ty = getElementTypeOrSelf(input_ty); - - // Output type is unranked if input type is not ranked. - auto ranked_ty = input_ty.dyn_cast(); - if (!ranked_ty) return UnrankedTensorType::get(element_ty); - int64_t rank = ranked_ty.getRank(); - - DenseIntElementsAttr indices; - if (!matchPattern(reduction_indices, m_Constant(&indices))) { - // Output type is unranked if reduction indices are not constant and reduced - // dimensions are not kept. - if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); - - // Otherwise, output type has same rank as the input. - return RankedTensorType::get(SmallVector(rank, -1), element_ty); - } - - int64_t num_reduce_dim = 0; - llvm::SmallVector is_reduce_dim(rank, false); - for (const APInt &index : indices.getValues()) { - int64_t dim = GetDimForAxis(index.getSExtValue(), rank); - // Invalid input. - if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); - - if (!is_reduce_dim[dim]) { - is_reduce_dim[dim] = true; - num_reduce_dim++; - } - } - - ArrayRef shape = ranked_ty.getShape(); - SmallVector out_shape; - out_shape.reserve(rank - (keep_dims.getValue() ? 
0 : num_reduce_dim)); - for (int64_t i = 0; i < rank; ++i) { - if (!is_reduce_dim[i]) - out_shape.push_back(shape[i]); - else if (keep_dims.getValue()) - out_shape.push_back(1); - } - return RankedTensorType::get(out_shape, element_ty); -} - -// Verifies that the given types are cast compatible. If not, emits appropriate -// error for the given op. If mask_one_dim is set to true, then the types are -// allowed to have one mismatching dimension. Masking one of the dimensions is -// useful for ops like Concat that requires all ranked inputs to have the same -// rank and match dimension sizes for all but one of the dimensions. -static LogicalResult VerifyTypesCompatibility( - Operation::operand_type_range types, bool mask_one_dim, Operation *op) { - constexpr int64_t kUninitialized = -1; - int64_t common_rank = kUninitialized; - llvm::SmallVector common_dims; - int64_t dim_to_mask = kUninitialized; - - // Initialize common_rank with rank of the first ranked type and verify that - // following ranked types have the same rank. - // Similarly, initialize each of the dimensions with the first type that has - // the dimension size available and verify that all following types have the - // same size for the dimension. However, if mask_one_dim is true, note down - // the dimension index on the first mismatch and ignore dimension at that - // index in following types. - for (Type ty : types) { - RankedTensorType ranked_ty = ty.dyn_cast(); - if (!ranked_ty) continue; - - int64_t rank = ranked_ty.getRank(); - if (common_rank == kUninitialized) { - common_rank = rank; - common_dims.resize(common_rank, kUninitialized); - } else if (common_rank != rank) { - return op->emitError() - << "operand type " << ranked_ty - << " is not compatible with preceding operands; expected rank: " - << common_rank; - } - - for (int64_t i = 0, e = common_rank; i != e; i++) { - if (i == dim_to_mask) continue; - - int64_t dim = ranked_ty.getDimSize(i); - if (dim == kUninitialized) continue; - - int64_t &common_dim = common_dims[i]; - if (common_dim == kUninitialized) { - common_dim = dim; - } else if (common_dim != dim) { - // If mask_one_dim is true, do not emit an error if this is the only - // dimension with mismatches. Note down the dimension to mask it from - // the following types. - if (mask_one_dim && dim_to_mask == kUninitialized) { - dim_to_mask = i; - continue; - } - - return op->emitError() << "operand type " << ranked_ty - << " is not compatible with preceding operands; " - "expected dimension at index " - << i << ": " << common_dim; - } - } - } - return success(); -} - -// This is a helper for the Select to SelectV2 canonicalization. The `data` rank -// refers to the rank of `t`/`e` (these two inputs have equal rank; this is -// checked in the verifier). -// -// In most cases, the predicate for Select can be used directly as the predicate -// for SelectV2. However, there is one case that varies, which is when the -// predicate is a tensor and the data is multidimensional. In this case, Select -// op semantics dictate that the predicate tensor length must match the size of -// the first data dimension. This varies from normal broadcasting semantics -// (which are used in SelectV2), so we must reshape the tensor in this case to -// be compatible. -static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, - Value cond, int data_rank) { - auto cond_tensor = cond.getType().cast(); - // Reshape is only needed in the case that the cond rank is 1 (i.e. it is - // a vector) AND t/e rank is > 1. 
- if (cond_tensor.getRank() != 1 || data_rank <= 1) { - // No reshape necessary. Leave cond as it is. - return cond; - } - - // This is the case where a reshape is needed. We want to construct the - // shape [x,1,...1], where x is the value in the pred tensor and the - // length of the shape is equal to data_rank. - SmallVector shape(data_rank, 1); - shape[0] = cond_tensor.getShape().front(); - auto new_shape_type = - RankedTensorType::get({data_rank}, builder->getIntegerType(64)); - auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); - auto new_shape = builder->create(loc, shape_attr); - return builder->create(loc, cond, new_shape); -} - -//===----------------------------------------------------------------------===// -// Helper functions detect device capabilities from RuntimeDevices. -//===----------------------------------------------------------------------===// - -namespace { -using DeviceNameUtils = ::tensorflow::DeviceNameUtils; -using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; - -bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { - return device.type == ::tensorflow::DEVICE_GPU; -} - -} // namespace - -// Returns true if at least one GPU device is available at runtime. -bool CanUseGpuDevice(const RuntimeDevices &devices) { - return llvm::any_of(devices.device_names(), IsGpuDevice); -} - -// Returns true if all of the GPUs available at runtime support TensorCores -// (NVIDIA compute capability >= 7.0). -bool CanUseTensorCores(const RuntimeDevices &devices) { - auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { - auto md = devices.GetGpuDeviceMetadata(device); - return md ? md->cc_major().getInt() >= 7 : false; - }; - return llvm::all_of( - llvm::make_filter_range(devices.device_names(), IsGpuDevice), - has_tensor_cores); -} - -// Returns true if operation does not have explicit device placement that would -// prevent it from running on GPU device. -bool CanUseGpuDevice(Operation *op) { - auto device_attr = op->getAttrOfType("device"); - if (!device_attr || device_attr.getValue().empty()) return true; - - DeviceNameUtils::ParsedName device; - if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) - return false; - - // We can't use GPU if operation explicitly placed on non-GPU device. - return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; -} - -//===----------------------------------------------------------------------===// -// TF op helper functions to work with layout transformation. -//===----------------------------------------------------------------------===// - -SmallVector ReversePermutation(ArrayRef permutation) { - SmallVector reverse(permutation.size()); - for (size_t i = 0; i < permutation.size(); ++i) { - reverse[permutation[i]] = i; - } - return reverse; -} - -SmallVector GetDataFormatPermutation(StringRef from, StringRef to) { - if (from == "NHWC" && to == "NCHW") { - return {0, 3, 1, 2}; - } else if (from == "NCHW" && to == "NHWC") { - return {0, 2, 3, 1}; - } else { - return {}; - } -} - -// Shuffle elements in the `attr` according to the permutation. Optional -// `inner_size` allows to shuffle array attributes created from rank 2 tensors -// on outer dimension only. 
-ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, - int inner_size = 1) { - if (attr.size() == 0) return attr; - - assert(attr.size() % inner_size == 0); - assert(attr.size() / inner_size == permutation.size()); - - SmallVector values{attr.begin(), attr.end()}; - SmallVector shuffled(values.size()); - - for (size_t i = 0; i < permutation.size(); ++i) { - for (size_t j = 0; j < inner_size; ++j) { - shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; - } - } - - return ArrayAttr::get(shuffled, attr.getContext()); -} - -// Shuffle ranked tensor dimensions according to the permutation. -Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { - if (auto ranked_type = type.dyn_cast()) { - ArrayRef shape = ranked_type.getShape(); - assert(permutation.size() == shape.size()); - - SmallVector new_shape(permutation.size()); - for (size_t i = 0; i < permutation.size(); ++i) - new_shape[i] = shape[permutation[i]]; - - return RankedTensorType::get(new_shape, ranked_type.getElementType()); - } - - return type; -} - -static bool AreCancellablePermutations(DenseIntElementsAttr perm0, - DenseIntElementsAttr perm1) { - if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; - if (perm0.getNumElements() != perm1.getNumElements()) return false; - - SmallVector perm0_values; - for (const auto &value : perm0.getIntValues()) - perm0_values.push_back(value.getSExtValue()); - - SmallVector perm1_values; - for (const auto &value : perm1.getIntValues()) - perm1_values.push_back(value.getSExtValue()); - - for (int i = 0; i < perm0_values.size(); ++i) { - if (perm0_values[perm1_values[i]] != i) return false; - } - - return true; -} - -// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for -// layout sensitive operations that do not have any additional layout dependent -// attributes besides `data_format` string. -template -LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { - auto perm = GetDataFormatPermutation(op->data_format(), data_format); - if (perm.empty()) return failure(); - - // Update data format attribute. - op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); - - // Update types for all layout sensitive results. - auto layout_sensitive = cast(op->getOperation()); - for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { - OpResult result = op->getOperation()->getResult(idx); - result.setType(ShuffleRankedTensorType(result.getType(), perm)); - } - - return success(); -} - -// Default implementation for folding operand transpose into the operation. -// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. -template -LogicalResult FoldOperandsPermutation( - ArrayRef permutation, Op *op, - ArrayRef> shuffle_attrs = {}) { - MLIRContext *context = op->template getParentOfType().getContext(); - - // We only support NHWC <-> NCHW permutations. - static constexpr std::array kNchwToNhwc = {0, 2, 3, 1}; - static constexpr std::array kNhwcToNchw = {0, 3, 1, 2}; - - // Operation data format after folding `permutation`. 
- StringRef target_data_format = [&]() -> StringRef { - if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { - return "NCHW"; // cancel NCHW->NHWC operand permutation - } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { - return "NHWC"; // cancel NHWC->NCHW operand permutation - } else { - return ""; - } - }(); - if (target_data_format.empty()) return failure(); - - // To fold operand `permutation` into the `op` we need shuffle all layout - // dependent attributes and types with a reverse permutation, and change - // operation data format to `target_data_format`. - // - // Example: - // %1 = SomeOp(...) {data_format = NHWC} - // %2 = Transpose(%1) {permutation = NHWC->NCHW} - // %3 = Op(%2) {data_format = NCHW} - // - // To bypass %2 we have to change data format to shuffle data format from NCHW - // to NHWC, which is the reverse of operand permutation (function argument). - auto reverse_permutation = - GetDataFormatPermutation(op->data_format(), target_data_format); - if (reverse_permutation.empty()) return failure(); - - op->setAttr("data_format", StringAttr::get(target_data_format, context)); - - for (auto pair : shuffle_attrs) { - StringRef attr_name = pair.first; - ArrayAttr attr_value = pair.second; - op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation)); - } - - auto fold = cast(op->getOperation()); - for (unsigned idx : fold.GetLayoutDependentResults()) { - OpResult result = op->getOperation()->getResult(idx); - result.setType( - ShuffleRankedTensorType(result.getType(), reverse_permutation)); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// Rewrite Pattern for removing trivial Arithmetic op. -//===----------------------------------------------------------------------===// - -namespace { -// Fold Arithmetic Op if one of the operands is a constant known to be an -// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if -// known identity value is either lhs or rhs. -template < - typename OpT, - typename std::enable_if::value>::type * = nullptr> -OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, - ArrayRef operands) { - auto lhs_type = arithmetic_op.x().getType().template cast(); - auto rhs_type = arithmetic_op.y().getType().template cast(); - auto result_type = - arithmetic_op.getResult().getType().template cast(); - - // We can fold arithmetic operation only of we can prove that we will not - // accidentally hide a broadcasting error. - auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, - ShapedType result_ty) -> bool { - // Scalar identity is broadcastable to any operand shape, we only need to - // check that operand has the same shape as a result. - bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; - if (scalar_identity) return operand_ty == result_ty; - - // If identity is not a scalar, we must verify that all shapes are equal - // and statically known. - // - // TODO(ezhulenev): Fold if identity shape is statically know to be - // broadcastable to the operand shape. - return operand_ty == result_ty && identity_ty == result_ty && - result_ty.hasStaticShape(); - }; - - // Check that we have a constant operand on one side (candidate for identity). 
- const bool is_commutative = - (std::is_same::value || std::is_same::value); - auto lhs_attr = operands[0].dyn_cast_or_null(); - auto rhs_attr = operands[1].dyn_cast_or_null(); - if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; - - // Mul and Div ops have identity value one while AddV2 and SubOp have identity - // value zero. - const int identity = - (std::is_same::value || std::is_same::value || - std::is_same::value) - ? 1 - : 0; - - Type element_ty = lhs_type.getElementType(); - Attribute identity_attr; - if (auto ty = element_ty.template dyn_cast()) { - identity_attr = FloatAttr::get(ty, static_cast(identity)); - } else if (auto ty = element_ty.template dyn_cast()) { - identity_attr = IntegerAttr::get(ty, static_cast(identity)); - } else { - return {}; - } - - // Fold: Op(Operand, Identity) -> Operand. - if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { - if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.x(); - } - - // Fold: Op(Identity, Operand) -> Operand for commutative operations. - if (lhs_attr && is_commutative && - is_valid_broadcasting(rhs_type, lhs_type, result_type)) { - if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) - return arithmetic_op.y(); - } - - return {}; -} -} // namespace - -namespace { -#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" -} // namespace - -//===----------------------------------------------------------------------===// -// AddOp -//===----------------------------------------------------------------------===// - -void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// AddNOp -//===----------------------------------------------------------------------===// - -OpFoldResult AddNOp::fold(ArrayRef operands) { - if (operands.size() == 1) return *inputs().begin(); - return {}; -} - -//===----------------------------------------------------------------------===// -// AddV2Op -//===----------------------------------------------------------------------===// - -void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult AddV2Op::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// AllOp -//===----------------------------------------------------------------------===// - -// Verifies an reduction op's `input` and reduction `dims`. 
-static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, - Location loc) { - auto dims_type = dims.getType().dyn_cast(); - if (!dims_type) return success(); - if (dims_type.getRank() > 1) - return emitError(loc, "dimensions can only be 0D or 1D tensor"); - - auto input_type = input.getType().dyn_cast(); - if (!input_type) return success(); - int64_t rank = input_type.getRank(); - - DenseIntElementsAttr dims_attr; - if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); - for (const auto &dim_pair : llvm::enumerate(dims_attr)) { - int64_t cur_dim = dim_pair.value().getSExtValue(); - if (cur_dim < -rank || cur_dim >= rank) - return emitError(loc) - << dim_pair.index() << "-th dimension should be in the range of [-" - << rank << ", " << rank << ")"; - } - - return success(); -} - -static LogicalResult Verify(AllOp op) { - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), - op.getLoc()); -} - -//===----------------------------------------------------------------------===// -// AnyOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(AnyOp op) { - return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), - op.getLoc()); -} - -//===----------------------------------------------------------------------===// -// AssertOp -//===----------------------------------------------------------------------===// - -namespace { - -// Removes Assert with constant true predicate. -struct AssertWithTrue : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AssertOp op, - PatternRewriter &rewriter) const override { - ElementsAttr cst; - if (matchPattern(op.condition(), m_Constant(&cst))) { - if (cst.getValue({}).getValue()) { - rewriter.eraseOp(op); - return success(); - } - } - return failure(); - } -}; -} // namespace - -void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchMatMulOp -//===----------------------------------------------------------------------===// - -void BatchMatMulOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchMatMulV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BatchMatMulV2Op op) { - if (!HasRankAtLeast(op.x(), 2)) { - return op.emitOpError("requires lhs operand to have rank at least two"); - } - if (!HasRankAtLeast(op.y(), 2)) { - return op.emitOpError("requires rhs operand to have rank at least two"); - } - return success(); -} - -void BatchMatMulV2Op::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BatchToSpaceOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BatchToSpaceOp op) { - // Op already has a constraint that block_size >= 2. 
- int64_t block_size = op.block_size().getSExtValue(); - - llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); - auto input_type = op.input().getType().cast(); - if (input_type.hasRank()) { - if (input_type.getRank() != 4) - return op.emitOpError() - << "requires input to be a 4D tensor, but got " << input_type; - - int64_t input_batch = input_type.getDimSize(0); - if (input_batch != ShapedType::kDynamicSize && - input_batch % (block_size * block_size) != 0) { - return op.emitOpError() - << "requires input batch (dimension 0) to be evenly divisible " - "by (block_size * block_size), but got input batch " - << input_batch << " and block_size " << block_size; - } - - input_shape.assign(input_type.getShape().begin(), - input_type.getShape().end()); - } - - auto crops_type = op.crops().getType().cast(); - if (crops_type.hasRank()) { - if (crops_type.getRank() != 2) - return op.emitOpError() - << "requires crops to be a 2D tensor, but got " << crops_type; - - auto dim_of_size = [&](int64_t dim, int64_t size) { - if (crops_type.isDynamicDim(dim)) return true; - return crops_type.getDimSize(dim) == size; - }; - if (!dim_of_size(0, 2) || !dim_of_size(1, 2)) - return op.emitOpError() - << "requires crops to be a tensor<2x2>, but got " << crops_type; - } - - DenseIntElementsAttr crops_attr; - // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], - // and flattened as [crop_top, crop_bottom, crop_left, crop_right] - llvm::SmallVector crops_values; - if (matchPattern(op.crops(), m_Constant(&crops_attr))) { - assert(crops_attr.getNumElements() == 4 && - "tf.BatchToSpace crops must have 4 elements"); - - auto crops_range = crops_attr.getIntValues(); - for (const auto &crops_value : crops_range) { - int64_t crops_value_int = crops_value.getSExtValue(); - if (crops_value_int < 0) - return op.emitOpError() - << "requires all crop values to be nonnegative, but got " - << crops_attr; - - crops_values.push_back(crops_value_int); - } - } - - auto output_type = op.output().getType().cast(); - if (output_type.hasRank()) { - if (output_type.getRank() != 4) - return op.emitOpError() - << "requires output to be a 4D tensor, but got " << output_type; - - auto static_dims = [](int64_t dim_a, int64_t dim_b) { - return dim_a != ShapedType::kDynamicSize && - dim_b != ShapedType::kDynamicSize; - }; - - auto output_shape = output_type.getShape(); - - // output batch = input batch / (block_size * block_size). - int64_t input_batch = input_shape[0]; - int64_t output_batch = output_shape[0]; - if (static_dims(input_batch, output_batch) && - (output_batch * block_size * block_size) != input_batch) - return op.emitOpError() - << "requires output batch (dimension 0) to be equal to input " - "batch (dimension 0) / (block_size * block_size), but got " - "output batch " - << output_batch << ", input batch " << input_batch - << ", and block_size " << block_size; - - auto check_spatial_dim = [&](int64_t spatial_dim_index, - llvm::StringRef dim_name, - llvm::StringRef crop_a_name, - llvm::StringRef crop_b_name) -> LogicalResult { - int64_t input_dim = input_shape[spatial_dim_index]; - int64_t output_dim = output_shape[spatial_dim_index]; - if (!static_dims(input_dim, output_dim)) return success(); - - int64_t input_dim_pad = input_dim * block_size; - // If crops are unknown, the maximum output spatial dim size is input - // spatial dim size * block_size, as crops can be minimum 0. 
- if (crops_values.empty() && output_dim > input_dim * block_size) - return op.emitOpError() - << "requires output " << dim_name << " (dimension " - << spatial_dim_index << ") to be less than or equal to input " - << dim_name << " (dimension " << spatial_dim_index - << ") * block_size, but got output " << dim_name << " " - << output_dim << ", input " << dim_name << " " << input_dim - << ", and block_size " << block_size; - - if (!crops_values.empty()) { - // output spatial dim = input spatial dim * block_size - crops. - int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)]; - int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1]; - if (output_dim != input_dim_pad - crop_a - crop_b) - return op.emitOpError() - << "requires output " << dim_name << " (dimension " - << spatial_dim_index << ") to be equal to input " << dim_name - << " (dimension " << spatial_dim_index << ") * block_size - " - << crop_a_name << " - " << crop_b_name << ", but got output " - << dim_name << " " << output_dim << ", input " << dim_name - << " " << input_dim << ", " << crop_a_name << " " << crop_a - << ", " << crop_b_name << " " << crop_b << ", and block_size " - << block_size; - } - - return success(); - }; - - if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) || - failed(check_spatial_dim(2, "width", "crop_left", "crop_right"))) - return failure(); - - int64_t input_depth = input_shape[3]; - int64_t output_depth = output_shape[3]; - if (static_dims(input_depth, output_depth) && output_depth != input_depth) - return op.emitOpError() - << "requires output depth (dimension 3) to be equal to input " - "depth (dimension 3), but got output depth " - << output_depth << " and input depth " << input_depth; - } - - return success(); -} - -void BatchToSpaceOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BiasAddOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * the value and bias operands have valid ranks or are unranked. -// * Channel dimension of the value operand and length of bias matches if they -// are not unknown. -// -static LogicalResult Verify(BiasAddOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { - if (!HasRankAtLeast(op.value(), 2)) - return op.emitOpError( - "requires value operand to have rank at least two with `NHWC` data " - "format"); - } else { - // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); - if (!HasRankAtLeast(op.value(), 3)) - return op.emitOpError( - "requires value operand to have rank at least three with `NCHW` data " - "format"); - } - - if (!IsOfRankOrUnranked(op.bias(), 1)) - return op.emitOpError("requires bias operand to have rank exactly one"); - - RankedTensorType value_ty = op.value().getType().dyn_cast(); - RankedTensorType bias_ty = op.bias().getType().dyn_cast(); - if (!bias_ty || !value_ty) return success(); - - // TODO(hinsu): Leverage tensor_format.h utility in TensorFlow to compute - // dimension indices based on format. - int64_t feature_dim_idx = format == "NHWC" ? 
value_ty.getRank() - 1 : 1; - int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); - int64_t bias_len = bias_ty.getDimSize(0); - if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { - return op.emitOpError() - << "requires channel dimension and feature dimension to match; " - "found " - << feature_dim << " and " << bias_len << ", respectively"; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// BiasAddGradOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * the out_backprop operands have valid ranks or are unranked. -// -static LogicalResult Verify(BiasAddGradOp op) { - StringRef format = op.data_format(); - if (format == "NHWC") { - if (!HasRankAtLeast(op.out_backprop(), 2)) - return op.emitOpError( - "requires out_backprop operand to have rank at least two with `NHWC` " - "data format"); - } else { - // Op definition requires data_format to be either NHWC or NCHW. - DCHECK_EQ(format.str(), "NCHW"); - if (!HasRankAtLeast(op.out_backprop(), 3)) - return op.emitOpError( - "requires out_backprop operand to have rank at least three with " - "`NCHW` data format"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// BiasAddV1Op -//===----------------------------------------------------------------------===// - -void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// BroadcastToOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(BroadcastToOp op) { - // TODO(antiagainst): check that - // * The 'shape' input is an 1-D int tensor. - // * Each dimension pair of the source and target shapes are either equal - // or one of them is one. - return success(); -} - -//===----------------------------------------------------------------------===// -// CaseOp -//===----------------------------------------------------------------------===// - -class FoldConstantCaseOp : public OpRewritePattern { - public: - explicit FoldConstantCaseOp(MLIRContext *context) - : OpRewritePattern(context) {} - LogicalResult matchAndRewrite(TF::CaseOp op, - PatternRewriter &rewriter) const override; -}; - -LogicalResult FoldConstantCaseOp::matchAndRewrite( - TF::CaseOp op, PatternRewriter &rewriter) const { - // Extract the constant cond value. - DenseIntElementsAttr branch; - if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); - - // Only attempt to fold scalar valued case statements. - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. - if (!branch.getType().cast().getShape().empty()) - return failure(); - - int index = *branch.getValues().begin(); - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. 
- if (index >= op.branches().size()) return failure(); - - auto func = op.branches()[index].cast(); - auto empty = rewriter.getStringAttr(""); - auto call_op = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, - /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateAttributes(op.getOperation(), call_op); - rewriter.replaceOp(op, call_op.getResults()); - return success(); -} - -void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// CastOp -//===----------------------------------------------------------------------===// - -OpFoldResult CastOp::fold(ArrayRef operands) { - // Cast with the same type is a no-op. - Value operand = getOperand(); - if (getType() == operand.getType()) return operand; - return {}; -} - -//===----------------------------------------------------------------------===// -// ConcatOp and ConcatV2Op -//===----------------------------------------------------------------------===// - -template ::value>::type * = nullptr> -static LogicalResult Verify(OpT op) { - // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); - - int axis_idx = std::is_same() ? 0 : 1; - Value axis = *op.getODSOperands(axis_idx).begin(); - if (!HasRankAtMost(axis, 1)) { - return op.emitOpError( - "requires axis to be of scalar type (or vector type for older " - "versions)"); - } - - return VerifyTypesCompatibility(values, - /*mask_one_dim=*/true, op.getOperation()); -} - -void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ConcatOffsetOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ConcatOffsetOp op) { - if (op.N() < 2) - return op.emitOpError() << "requires N to be at least 2, got " << op.N(); - - if (op.shape().size() != op.offset().size()) - return op.emitOpError() - << "requires sizes of shapes and offsets to be the same, got sizes " - << op.shape().size() << " and " << op.offset().size(); - - auto ranked_dim = op.concat_dim().getType().dyn_cast(); - if (ranked_dim && ranked_dim.getRank() != 0) - return op.emitOpError() - << "requires concat_dim to be a scalar, got tensor of rank " - << ranked_dim.getRank(); - - int64_t num_dims = -1; - for (auto shape_offset_idx : - llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { - Value shape = std::get<0>(shape_offset_idx.value()); - Value offset = std::get<1>(shape_offset_idx.value()); - const size_t idx = shape_offset_idx.index(); - - if (failed(verifyCompatibleShape(shape.getType(), offset.getType()))) - return op.emitOpError() << "requires operand and result " << idx - << " to have compatible shapes"; - - auto ranked_shape = shape.getType().dyn_cast(); - if (!ranked_shape) continue; - - if (ranked_shape.getRank() != 1) - return op.emitOpError() << "requires shape tensor operand " << idx - << " to be of rank 1, got tensor of rank " - << ranked_shape.getRank(); - - if (!ranked_shape.hasStaticShape()) continue; - - int64_t ranked_shape_dim = ranked_shape.getDimSize(0); - if (num_dims == -1) - num_dims = ranked_shape_dim; - else if (ranked_shape_dim != num_dims) - return op.emitOpError() - << "requires shape tensor (rank 1) 
operand " << idx - << " to be of length " << num_dims - << ", got tensor (rank 1) of length " << ranked_shape_dim; - } - - return success(); -} - -LogicalResult ConcatOffsetOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - // ConcatOffset must have its first operand be concat_dim and at least two - // shape tensors in variadic shapes operand. - if (operands.size() < 3) return failure(); - - // Check concat_dim is a scalar. - auto concat_dim_attr = operands[0].dyn_cast_or_null(); - if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) - return failure(); - - llvm::SmallVector shapes; - shapes.reserve(operands.size() - 1); - for (Attribute shape : llvm::drop_begin(operands, 1)) - if (auto shape_attr = shape.dyn_cast_or_null()) - shapes.push_back(shape_attr); - else - return failure(); - - // Check all shapes are vectors of the same length. - if (shapes.front().getType().getRank() != 1) return success(); - const int64_t num_dims = shapes.front().getNumElements(); - for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) - if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims) - return failure(); - - // Check concat_dim is within [-num_dims, num_dims). - int32_t concat_dim = (*concat_dim_attr.getValues().begin()); - if (concat_dim < 0) concat_dim += num_dims; - if (concat_dim >= num_dims || concat_dim < 0) return failure(); - - // Check all elements besides at concat_dim match across all shape tensors. - SmallVector shape0; - shape0.reserve(num_dims); - for (int32_t dim : shapes.front().getValues()) shape0.push_back(dim); - - for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { - for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { - if (dims_and_idx.index() == concat_dim) continue; - - if (std::get<0>(dims_and_idx.value()) != - std::get<1>(dims_and_idx.value()).getSExtValue()) - return failure(); - } - } - - // Compute an exclusive cumulative sum of elements at concat_dim. - results.reserve(shapes.size()); - SmallVector cumulative_sum(num_dims, 0); - RankedTensorType offset_type = - RankedTensorType::get({num_dims}, IntegerType::get(32, getContext())); - for (DenseIntElementsAttr shape : shapes) { - results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); - cumulative_sum[concat_dim] += shape.getValue(concat_dim); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// ConjOp -//===----------------------------------------------------------------------===// - -void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ConstOp -//===----------------------------------------------------------------------===// - -OpFoldResult ConstOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); - - // Return the held attribute value. - return value(); -} - -// Builds a constant op with the specified attribute `value`. The result -// op's type is deduced from `value`; if `value` is of scalar type, -// wraps it up with a tensor type of empty shape. -// TODO(jpienaar): This one differs from the autogenerated one as it takes an -// attribute but always creates an ElementsAttr internally. 
-void ConstOp::build(OpBuilder &builder, OperationState &result, - Attribute value) { - ShapedType type; - if (auto elem_attr = value.dyn_cast()) { - return ConstOp::build(builder, result, elem_attr); - } else if (value.isa()) { - // All TensorFlow types must be tensor types. In the build() method, - // we want to provide more flexibility by allowing attributes of scalar - // types. But we need to wrap it up with ElementsAttr to construct - // valid TensorFlow constants. - type = RankedTensorType::get(/*shape=*/{}, value.getType()); - return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); - } - // TODO(jpienaar): support other TensorFlow specific types. - llvm_unreachable("unsupported attribute type for building tf.Const"); -} - -void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, - Attribute value) { - // Handle the case where the type and value are already tensors. - if (type.isa() && value.isa()) { - result.addTypes(type); - result.addAttribute("value", value); - return; - } - - // Otherwise, default to the attribute builder. - ConstOp::build(builder, result, value); - assert(type == result.types[0] && "type mismatch in construction"); -} - -LogicalResult ConstOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - auto value = attributes.get("value"); - if (!value) return emitOptionalError(location, "missing attribute 'value'"); - if (auto elem_attr = value.dyn_cast()) { - inferredReturnTypes.assign({elem_attr.getType()}); - return success(); - } - return emitOptionalError(location, - "attribute 'value' failed to satisfy constraint: " - "constant vector/tensor"); -} - -//===----------------------------------------------------------------------===// -// Conv2DOp and Conv3DOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyConvOpAttributes(OpT op, int num_dims) { - if (!IsOfRankOrUnranked(op.getResult(), num_dims)) - return op.emitOpError() - << "requires result to be " << num_dims << "D tensor"; - - auto is_not_positive = [](Attribute val) { - return val.cast().getValue().getSExtValue() <= 0; - }; - - int64_t strides_size = op.strides().size(); - if (strides_size != num_dims) - return op.emitOpError() << "requires strides attribute length to be " - << num_dims << "; actual length " << strides_size; - if (llvm::any_of(op.strides().getValue(), is_not_positive)) - return op.emitOpError("requires positive strides"); - - int64_t dilations_size = op.strides().size(); - if (op.dilations().size() != num_dims) - return op.emitOpError() << "requires dilations attribute length to be " - << num_dims << "; actual length " << dilations_size; - if (llvm::any_of(op.dilations().getValue(), is_not_positive)) - return op.emitOpError("requires positive dilations"); - - return success(); -} - -// Verifies that, -// * Ranks of operands and result are valid -// * Number of input channels is divisible by the number of filter input -// channels -// * Length of explicit_paddings attribute is valid and has non negative -// elements -// * strides and dilations attributes have positive elements -template ::value>::type * = nullptr> -static LogicalResult Verify(OpT op) { - int num_spatial_dims = std::is_same() ? 
2 : 3; - int num_dims = 2 + num_spatial_dims; - - if (!IsOfRankOrUnranked(op.input(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) - return op.emitOpError() - << "requires operands to be " << num_dims << "D tensor"; - - // EXPLICIT padding mode and the associated attribute is limited to Conv2D. - // So, fetch attribute by string instead of the op.explicit_paddings() - // attribute getter. - if (op.padding() == "EXPLICIT") { - auto paddings = op.template getAttrOfType("explicit_paddings"); - if (!paddings) - return op.emitOpError() << "requires attribute 'explicit_paddings' with " - "'EXPLICIT' padding mode"; - - int64_t paddings_size = paddings.size(); - int64_t expected_size = 2 * num_dims; - - if (paddings_size != expected_size) - return op.emitOpError() - << "requires explicit_paddings attribute length to be " - << expected_size << "; actual length " << paddings_size; - - auto is_negative = [](Attribute val) { - return val.cast().getValue().getSExtValue() < 0; - }; - if (llvm::any_of(paddings.getValue(), is_negative)) - return op.emitOpError("requires non negative explicit paddings"); - } - - LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); - if (failed(verify_result)) { - return verify_result; - } - - int64_t input_channels = -1; - if (auto ty = op.input().getType().template dyn_cast()) { - std::string data_format = op.data_format().str(); - tensorflow::TensorFormat format; - auto is_valid = FormatFromString(data_format, &format); - DCHECK(is_valid) << data_format; - int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format); - input_channels = ty.getDimSize(idx); - } - - int64_t filter_channels = -1; - if (auto ty = op.filter().getType().template dyn_cast()) { - int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( - num_dims, tensorflow::FORMAT_HWIO); - filter_channels = ty.getDimSize(idx); - } - - if (input_channels != -1 && filter_channels != -1 && - input_channels % filter_channels != 0) - return op.emitOpError() - << "requires the number of input channels to be divisible by the " - "number of filter input channels; found " - << input_channels << " and " << filter_channels << ", respectively"; - - return success(); -} - -LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { - auto perm = GetDataFormatPermutation(this->data_format(), data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - return success(); -} - -StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. 
- const bool is_f16 = input_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For f32/f16 data type decision depends on the filter size in spatial - // dimensions, for other data types we keep current data format. - if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) - return data_format(); - - // Keep current data format if filter rank is unknown or not equal to 4. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty || filter_ty.getRank() != 4) return data_format(); - - const int64_t d0 = filter_ty.getDimSize(0); - const int64_t d1 = filter_ty.getDimSize(1); - - auto all_ones = [](ArrayAttr arr) -> bool { - return llvm::all_of(arr, [](Attribute attr) -> bool { - return attr.cast().getInt() == 1; - }); - }; - - // Convolutions with 1x1 filter and with strides and dilations all ones, can - // be computed as a GEMM in NHWC data format, and can be up to ~2x times - // faster than convolution in NCHW. - const bool one_by_one = d0 == 1 && d1 == 1; - const bool trivial_strides = all_ones(strides()); - const bool trivial_dilations = all_ones(dilations()); - - // TODO(ezhulenev): This might lead to excessive transposes in the final IR, - // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. - // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all - // users can efficiently use NHWC data format? - if (one_by_one && trivial_strides && trivial_dilations) { - return "NHWC"; - } - - // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because - // it's the fastest option on NVIDIA GPUs with cuDNN library support. - return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// Conv2dBackpropFilterOp -//===----------------------------------------------------------------------===// - -LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); - - auto perm = GetDataFormatPermutation(src_data_format, data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - // Permute filter sizes operand. - OpBuilder builder(getOperation()); - auto filter_sizes_permuted = builder.create( - getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), - StringAttr::get(data_format, getContext())); - setOperand(1, filter_sizes_permuted); - - return success(); -} - -StringRef Conv2DBackpropFilterOp::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Input must be a tensor. - auto input_ty = input().getType().dyn_cast(); - if (!input_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - const bool is_f16 = input_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // Otherwise always use "NCHW". 
- return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// Conv2DBackpropInputOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(Conv2DBackpropInputOp op) { - int num_spatial_dims = 2; - int num_dims = 2 + num_spatial_dims; - - if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || - !IsOfRankOrUnranked(op.filter(), num_dims)) - return op.emitOpError() - << "requires operands to be " << num_dims << "D tensor"; - - LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); - if (failed(verify_result)) { - return verify_result; - } - - return success(); -} - -LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { - StringRef src_data_format = this->data_format(); - - auto perm = GetDataFormatPermutation(src_data_format, data_format); - if (perm.empty()) return failure(); - - // Update data_format attribute and result types. - if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); - - // Update convolution attributes. - setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); - setAttr("strides", ShuffleArrayAttr(strides(), perm)); - setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); - - // Permute input sizes operand. - OpBuilder builder(getOperation()); - auto input_sizes_permuted = builder.create( - getLoc(), input_sizes(), StringAttr::get(src_data_format, getContext()), - StringAttr::get(data_format, getContext())); - setOperand(0, input_sizes_permuted); - - return success(); -} - -StringRef Conv2DBackpropInputOp::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // Filter must be a tensor. - auto filter_ty = filter().getType().dyn_cast(); - if (!filter_ty) return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - const bool is_f16 = filter_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // Otherwise always use "NCHW". 
- return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// DataFormatVecPermuteOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(DataFormatVecPermuteOp op) { - auto input_ty = op.x().getType().dyn_cast(); - if (!input_ty) return success(); - - int rank = input_ty.getRank(); - if (rank != 1 && rank != 2) - return op.emitOpError("requires input of rank 1 or 2"); - - if (rank == 1) { - int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) - return op.emitOpError("requires 1D input of size 4 or size 2"); - } - - if (rank == 2) { - int64_t dim0 = input_ty.getDimSize(0); - if (dim0 != ShapedType::kDynamicSize && dim0 != 4) - return op.emitOpError( - "requires first dimensions of 2D input to be of size 4"); - - int64_t dim1 = input_ty.getDimSize(1); - if (dim1 != ShapedType::kDynamicSize && dim1 != 2) - return op.emitOpError( - "requires second dimensions of 2D input to be of size 2"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// DivOp -//===----------------------------------------------------------------------===// - -void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult DivOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// DynamicStitchOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(DynamicStitchOp op) { - if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); - - if (RankedTensorType out_ty = op.getType().dyn_cast()) { - if (out_ty.getRank() == 0) { - return op.emitOpError("requires non scalar output"); - } - } - - llvm::SmallDenseSet index_values; - bool all_indices_const = true; - int32_t max_index = -1; - llvm::Optional> inferred_item_shape; - for (auto it : llvm::zip(op.indices(), op.data())) { - Value index = std::get<0>(it); - - DenseIntElementsAttr index_attr; - if (matchPattern(index, m_Constant(&index_attr))) { - for (int32_t index : index_attr.getValues()) { - if (index < 0) - return op.emitOpError() - << "requires non-negative index values; found " << index; - max_index = std::max(index, max_index); - index_values.insert(index); - } - } else { - all_indices_const = false; - } - - Value data = std::get<1>(it); - RankedTensorType index_ty = index.getType().dyn_cast(); - RankedTensorType data_ty = data.getType().dyn_cast(); - if (!index_ty || !data_ty) continue; - - int64_t index_rank = index_ty.getRank(); - ArrayRef data_shape = data_ty.getShape(); - ArrayRef index_shape = index_ty.getShape(); - if (failed(mlir::verifyCompatibleShape(index_shape, - data_shape.take_front(index_rank)))) - return op.emitOpError() << "requires shape of data with type " << data_ty - << " to have prefix matching with shape of the " - "corresponding index type " - << index_ty; - - ArrayRef item_shape = data_shape.drop_front(index_rank); - if (!inferred_item_shape) { - inferred_item_shape = llvm::to_vector<4>(item_shape); - continue; - } - - if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) - return op.emitOpError() << "has inconsistent shaped data and index " - "pairs; inferred item shapes [" - << llvm::makeArrayRef(*inferred_item_shape) - << "] 
and [" << item_shape << "] don't match"; - for (int i = 0, e = item_shape.size(); i < e; ++i) { - int64_t &inferred_dim = (*inferred_item_shape)[i]; - int64_t dim = item_shape[i]; - if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; - } - } - - // If all indices are constants, then verify that they cover all indices in - // the range [0, max_index] and the output type is legal. - if (all_indices_const) { - for (int32_t i = 0; i <= max_index; i++) { - if (!index_values.count(i)) - return op.emitOpError() << "missing index " << i; - } - - if (inferred_item_shape) { - SmallVector expected_shape; - expected_shape.push_back(max_index + 1); - expected_shape.append(inferred_item_shape->begin(), - inferred_item_shape->end()); - - auto out_ty = op.getType().cast(); - auto expected_out_ty = - RankedTensorType::get(expected_shape, out_ty.getElementType()); - - if (!AreCastCompatible({out_ty, expected_out_ty})) { - return op.emitOpError() << "has invalid output type; should be " - "compatible with inferred type " - << expected_out_ty; - } - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// EinsumOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// * Arity of the op is at most two. -// -// TODO(hinsu): Verify einsum equation attribute. -static LogicalResult Verify(EinsumOp op) { - if (op.N() > 2) { - return op.emitOpError("supports at most two operands"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// EmptyOp -//===----------------------------------------------------------------------===// - -OpFoldResult EmptyOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "empty op has one operand"); - - Attribute attr = operands.front(); - if (!attr) return {}; - - auto int_attr = attr.cast(); - SmallVector out_shape; - for (const auto val : int_attr.getValues()) { - out_shape.push_back(val); - } - - auto type = getResult().getType().cast(); - auto etype = type.getElementType(); - - // We can not fold if the result is not static. - if (!type.hasStaticShape()) return {}; - - if (auto float_type = etype.dyn_cast()) { - auto out_type = RankedTensorType::get(out_shape, float_type); - return DenseElementsAttr::get(out_type, - {APFloat(float_type.getFloatSemantics())}); - } - - if (auto int_type = etype.dyn_cast()) { - auto out_type = RankedTensorType::get(out_shape, etype); - APInt val(int_type.getWidth(), 0, int_type.getSignedness()); - return DenseElementsAttr::get(out_type, val); - } - - return {}; -} - -//===----------------------------------------------------------------------===// -// EmptyTensorListOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(EmptyTensorListOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - - if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { - return op.emitOpError("requires max_num_elements operand to be 0D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// EqualOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(EqualOp op) { - // If we allow inputs to have incompatible type, then nothing to do. 
- if (!op.incompatible_shape_error()) return success(); - - // Otherwise, check inputs are broadcastable. - return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( - op.getOperation()); -} - -void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, - incompatible_shape_error); - return build(builder, result, result_type, x, y, incompatible_shape_error); -} - -//===----------------------------------------------------------------------===// -// ExpandDimsOp -//===----------------------------------------------------------------------===// - -Type InferExpandDimsOpType(Value input, Value dim) { - Type element_ty = input.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - - auto input_ty = input.getType().dyn_cast(); - if (!input_ty) return unranked_ty; - - DenseIntElementsAttr dim_attr; - if (!matchPattern(dim, m_Constant(&dim_attr)) || - dim_attr.getNumElements() != 1) - return unranked_ty; - int64_t dim_val = (*dim_attr.begin()).getSExtValue(); - int64_t input_rank = input_ty.getRank(); - - if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty; - if (dim_val < 0) dim_val += input_rank + 1; - - SmallVector shape = llvm::to_vector<4>(input_ty.getShape()); - shape.insert(shape.begin() + dim_val, 1); - return RankedTensorType::get(shape, element_ty); -} - -void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, - Value input, Value dim) { - return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxArgsOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { - // TODO(fengliuai): moving the following to an utility method. - const llvm::fltSemantics &semantics = op.min().getSemantics(); - float rmin, rmax; - if (&semantics == &APFloat::IEEEsingle()) { - rmin = op.min().convertToFloat(); - rmax = op.max().convertToFloat(); - } else { - rmin = op.min().convertToDouble(); - rmax = op.max().convertToDouble(); - } - // Range boundaries must be valid. 
- if (rmin >= rmax) { - return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + - "," + Twine(std::to_string(rmax)) + "]"); - } - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxVarsOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { - auto min = GetRankedTensorTypeForOperand(op.min()); - if (min && !IsOfRankedFloatTensorType(min, 0)) - return op.emitOpError("requires min to be a 0d float tensor"); - - auto max = GetRankedTensorTypeForOperand(op.max()); - if (max && !IsOfRankedFloatTensorType(max, 0)) - return op.emitOpError("requires max to be a 0d float tensor"); - - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// FakeQuantWithMinMaxVarsPerChannelOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { - auto min = GetRankedTensorTypeForOperand(op.min()); - if (min && !IsOfRankedFloatTensorType(min, 1)) - return op.emitOpError("requires min to be a 1d float tensor"); - - auto max = GetRankedTensorTypeForOperand(op.max()); - if (max && !IsOfRankedFloatTensorType(max, 1)) - return op.emitOpError("requires max to be a 1d float tensor"); - - Value inputs = op.inputs(); - if (!HasRankAtLeast(inputs, 1)) - return op.emitError("requires inputs to be at least 1d float tensor"); - - int64_t num_bits = op.num_bits().getSExtValue(); - if (num_bits < 2 || num_bits > 16) { - return op.emitOpError( - "requires num_bits to be between 2 and 16, inclusive"); - } - - auto inputs_type = inputs.getType().dyn_cast(); - if (!inputs_type) return success(); - int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); - if ((min && min.getDimSize(0) != depth) || - (max && max.getDimSize(0) != depth)) { - return op.emitOpError( - "requires min and max to have same size as last dimension of inputs"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// FillOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(FillOp op) { - if (!IsOfRankOrUnranked(op.dims(), 1)) - return op.emitOpError() << "requires dims to be a 1D tensor"; - if (!IsOfRankOrUnranked(op.value(), 0)) - return op.emitOpError() << "requires value to be a scalar"; - - return success(); -} - -static ShapedType InferFillOpType(Value dims, Value value) { - Type etype = value.getType().cast().getElementType(); - - DenseIntElementsAttr dims_attr; - if (!matchPattern(dims, m_Constant(&dims_attr))) { - return UnrankedTensorType::get(etype); - } - - llvm::SmallVector shape; - shape.reserve(dims_attr.getNumElements()); - for (const APInt dim : dims_attr.getValues()) { - shape.push_back(dim.getSExtValue()); - } - return RankedTensorType::get(shape, etype); -} - -void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, - Value value) { - FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); -} - 
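The Fill helpers above (InferFillOpType and the fold that follows) hinge on one idea: if the dims operand is a compile-time constant, the result type is fully determined; otherwise it stays unranked. Below is a minimal standalone sketch of that decision only, deliberately free of the MLIR types used in this file; SimpleShape and InferFillShape are hypothetical names used for illustration, not part of the diff.

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// An empty optional stands in for "unranked" (shape unknown); a vector of
// extents stands in for a ranked shape.
using SimpleShape = std::optional<std::vector<int64_t>>;

// If the dims operand is known to be constant, the fill result is ranked with
// exactly those extents; otherwise nothing can be said about its shape.
SimpleShape InferFillShape(const std::optional<std::vector<int64_t>>& const_dims) {
  if (!const_dims) return std::nullopt;  // dims not constant -> unranked result
  return *const_dims;                    // constant dims -> ranked result
}

int main() {
  assert(!InferFillShape(std::nullopt).has_value());
  SimpleShape s = InferFillShape(std::vector<int64_t>{2, 3});
  assert(s.has_value() && s->size() == 2 && (*s)[0] == 2 && (*s)[1] == 3);
  return 0;
}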
-OpFoldResult FillOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "fill op has two operand"); - - auto type = getType().cast(); - // DenseElementsAttr that is used in this folder only supports int and float - // types. - // TODO(hinsu): Handle complex types once there is a attribute kind for - // complex. - if (!type.getElementType().isIntOrFloat()) return {}; - - auto value = operands[1].dyn_cast_or_null(); - if (!value) return {}; - - if (type.hasStaticShape()) - return DenseElementsAttr::get(type, value.getValue({})); - - auto dims = operands[0].dyn_cast_or_null(); - if (!dims) return {}; - - llvm::SmallVector shape; - shape.reserve(dims.getNumElements()); - for (const APInt dim : dims.getValues()) { - shape.push_back(dim.getSExtValue()); - } - type = RankedTensorType::get(shape, type.getElementType()); - - return DenseElementsAttr::get(type, value.getValue({})); -} - -//===----------------------------------------------------------------------===// -// FusedBatchNormGradOp -//===----------------------------------------------------------------------===// - -// TODO(b/150954845): Add benchmarks to verify that layout preference didn't -// change in the latest GPU generations. - -LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { - return ::mlir::TF::UpdateDataFormat(data_format, this); -} - -StringRef FusedBatchNormGradV3Op::GetOptimalLayout( - const RuntimeDevices &devices) { - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - auto x_ty = x().getType().cast(); - const bool is_f16 = x_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For all other data types prefer NCHW. - return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// FusedBatchNormOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(FusedBatchNormOp op) { - auto x = GetRankedTensorTypeForOperand(op.x()); - if (x && !IsOfRankedFloatTensorType(x, 4)) - return op.emitOpError("requires x to be a 4D float tensor"); - - auto scale = GetRankedTensorTypeForOperand(op.scale()); - if (scale && !IsOfRankedFloatTensorType(scale, 1)) - return op.emitOpError("requires scale to be a 1D float tensor"); - - auto offset = GetRankedTensorTypeForOperand(op.offset()); - if (offset && !IsOfRankedFloatTensorType(offset, 1)) - return op.emitOpError("requires offset to be a 1D float tensor"); - - auto mean = GetRankedTensorTypeForOperand(op.mean()); - if (mean && !IsOfRankedFloatTensorType(mean, 1)) - return op.emitOpError("requires mean to be a 1D float tensor"); - - auto variance = GetRankedTensorTypeForOperand(op.variance()); - if (variance && !IsOfRankedFloatTensorType(variance, 1)) - return op.emitOpError("requires variance to be a 1D float tensor"); - - // TODO(antiagainst): check attributes - - return success(); -} - -LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( - ArrayRef permutation) { - // FusedBatchNorm in training mode is a layout sentitive operation, and should - // have already assigned an optimal data format. 
- if (is_training()) return failure(); - - return ::mlir::TF::FoldOperandsPermutation(permutation, this); -} - -LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { - return ::mlir::TF::UpdateDataFormat(data_format, this); -} - -StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { - // In inference mode FusedBatchNorm is not sensitive to data layout. - if (!is_training()) return data_format(); - - // Keep current data format if no GPUs are available or if explicit placement - // does not allow to use GPU for this operation. - if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) - return data_format(); - - // For f16 data type on devices with Tensor Cores support NHWC data format - // is up to ~2x faster. - auto x_ty = x().getType().cast(); - const bool is_f16 = x_ty.getElementType().isF16(); - if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; - - // For all other data types prefer NCHW. - return "NCHW"; -} - -//===----------------------------------------------------------------------===// -// GatherV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(GatherV2Op op) { - int64_t batch_dims = op.batch_dims().getSExtValue(); - if (auto ty = op.indices().getType().dyn_cast()) { - int64_t rank = ty.getRank(); - if (batch_dims > rank || batch_dims < -rank) - return op.emitOpError() - << "batch_dims (" << batch_dims << ") must be in range [" << -rank - << ", " << rank + 1 << ")"; - if (batch_dims < 0) batch_dims += rank; - } - - if (!HasRankAtMost(op.axis(), 1)) - return op.emitOpError("requires axis to have rank at most 1"); - - DenseIntElementsAttr axis_attr; - if (matchPattern(op.axis(), m_Constant(&axis_attr))) { - int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (auto ty = op.params().getType().dyn_cast()) { - int64_t rank = ty.getRank(); - if (axis >= rank || axis < -rank) - return op.emitOpError() << "axis (" << axis << ") must be in range [" - << -rank << ", " << rank << ")"; - if (axis < 0) axis += rank; - } - - if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) { - return op.emitOpError() << "requires axis (" << axis - << ") to be greater than or equal to batch_dims (" - << batch_dims << ")"; - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// IfOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(IfOp op) { - auto module = op.getParentOfType(); - auto then_fn = module.lookupSymbol(op.then_branch()); - if (!then_fn) - return op.emitOpError("then_branch refers to an undefined function : ") - << op.then_branch(); - auto else_fn = module.lookupSymbol(op.else_branch()); - if (!else_fn) - return op.emitOpError("else_branch refers to an undefined function : ") - << op.else_branch(); - auto then_fn_type = then_fn.getType(); - auto else_fn_type = else_fn.getType(); - - // Non-conditional operands starting with the second operand are passed to - // branches and should be pair-wise compatible with branches' inputs. 
-  unsigned expected_num_inputs = op.getNumOperands() - 1;
-  if (then_fn_type.getNumInputs() != expected_num_inputs ||
-      else_fn_type.getNumInputs() != expected_num_inputs)
-    return op.emitError("branches should have " + Twine(expected_num_inputs) +
-                        " inputs");
-
-  for (unsigned i = 0; i < expected_num_inputs; ++i) {
-    auto operand_type = op.getOperand(i + 1).getType().cast<TensorType>();
-    auto then_input_type = then_fn_type.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operand_type, then_input_type}))
-      return op.emitError(
-          llvm::formatv("then branch input type {0} is incompatible with "
-                        "operand type {1} at index {2}",
-                        then_input_type, operand_type, i));
-
-    auto else_input_type = else_fn_type.getInput(i).cast<TensorType>();
-    if (!AreCastCompatible({operand_type, else_input_type}))
-      return op.emitError(
-          llvm::formatv("else branch input type {0} is incompatible with "
-                        "operand type {1} at index {2}",
-                        else_input_type, operand_type, i));
-
-    // If branches have incompatible input types that means that no tensor can
-    // serve as input to both the functions. Hence, the op is invalid.
-    if (!AreCastCompatible({then_input_type, else_input_type}))
-      return op.emitError(llvm::formatv(
-          "branches inputs have incompatible types {0} and {1} at index {2}",
-          then_input_type, else_input_type, i));
-  }
-
-  // Branches' results should be pair-wise compatible with the op results.
-  unsigned expected_num_results = op.getNumResults();
-  if (then_fn_type.getNumResults() != expected_num_results ||
-      else_fn_type.getNumResults() != expected_num_results)
-    return op.emitError("branches should have " + Twine(expected_num_results) +
-                        " results");
-
-  for (unsigned i = 0; i < expected_num_results; ++i) {
-    auto result_type = op.getResult(i).getType().cast<TensorType>();
-    auto then_result_type = then_fn_type.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({then_result_type, result_type}))
-      return op.emitError(
-          llvm::formatv("then branch result type {0} is incompatible with op "
-                        "result type {1} at index {2}",
-                        then_result_type, result_type, i));
-
-    auto else_result_type = else_fn_type.getResult(i).cast<TensorType>();
-    if (!AreCastCompatible({else_result_type, result_type}))
-      return op.emitError(
-          llvm::formatv("else branch result type {0} is incompatible with op "
-                        "result type {1} at index {2}",
-                        else_result_type, result_type, i));
-  }
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// IfRegionOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult VerifyRegionResults(Operation *op, Region &region,
-                                  StringRef region_name) {
-  auto op_name = op->getName().getStringRef();
-  // verify that op outputs match yield inputs
-  YieldOp yield = cast<YieldOp>(region.front().getTerminator());
-  unsigned expected_num_results = op->getNumResults();
-  if (yield.getNumOperands() != expected_num_results)
-    return op->emitOpError()
-           << region_name + " should have same number (" << expected_num_results
-           << ") of results as " << op_name << " but has "
-           << yield.getNumOperands() << " results";
-
-  for (int idx : llvm::seq<int>(0, expected_num_results)) {
-    auto op_result_type = op->getResult(idx).getType().cast<TensorType>();
-    auto region_result_type =
-        yield.getOperand(idx).getType().cast<TensorType>();
-    if (!AreCastCompatible({region_result_type, op_result_type}))
-      return op->emitError(llvm::formatv(
-          "{0} result type {1} is incompatible with {2} "
-          "result type {3} at index {4}",
-          region_name, region_result_type, op_name, op_result_type, idx));
-  }
-  return success();
-}
-
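The branch-signature checks above all reduce to cast compatibility between shaped types: element types must match and every statically known dimension must agree, with a dynamic dimension compatible with anything. A minimal standalone sketch of that rule under those assumptions, using a hypothetical ShapeSig struct in place of MLIR's TensorType; it is an illustration, not the AreCastCompatible implementation.

#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor type: element type name plus
// per-dimension sizes, where -1 marks a dynamic dimension.
struct ShapeSig {
  std::string element_type;
  std::vector<int64_t> dims;
  bool ranked = true;
};

// Two signatures are compatible when element types match and each static
// dimension agrees; unranked or dynamic entries are compatible with anything.
static bool CompatibleForIfBranch(const ShapeSig &a, const ShapeSig &b) {
  if (a.element_type != b.element_type) return false;
  if (!a.ranked || !b.ranked) return true;
  if (a.dims.size() != b.dims.size()) return false;
  for (size_t i = 0; i < a.dims.size(); ++i) {
    if (a.dims[i] == -1 || b.dims[i] == -1) continue;
    if (a.dims[i] != b.dims[i]) return false;
  }
  return true;
}

int main() {
  ShapeSig operand{"f32", {8, -1}};
  ShapeSig then_input{"f32", {8, 16}};
  ShapeSig else_input{"f32", {4, 16}};
  assert(CompatibleForIfBranch(operand, then_input));     // 8x? vs 8x16: ok
  assert(!CompatibleForIfBranch(then_input, else_input)); // 8x16 vs 4x16: error
  return 0;
}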
-static LogicalResult Verify(IfRegionOp op) { - if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) - return failure(); - if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) - return failure(); - return success(); -} - -//===----------------------------------------------------------------------===// -// InvertOp -//===----------------------------------------------------------------------===// - -void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// InvertPermutationOp -//===----------------------------------------------------------------------===// - -// Verifies that the input is 1D. -static LogicalResult Verify(InvertPermutationOp op) { - auto x_type = op.x().getType().cast(); - if (!x_type.hasRank()) return success(); - if (x_type.getShape().size() != 1) - return op.emitOpError() << "requires input x to be 1-dimensional"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// LeakyReluOp -//===----------------------------------------------------------------------===// - -OpFoldResult LeakyReluOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "leaky relu has one operand"); - - // leaky_relu(x, alpha: 1) -> x - if (alpha().convertToFloat() == 1.0f) return getOperand(); - - auto calculate = [&](FloatAttr arg) { - APFloat val = arg.getValue(); - if (val.isNegative()) val = alpha() * val; - return FloatAttr::get(arg.getType(), val); - }; - - if (auto arg = operands[0].dyn_cast_or_null()) { - return calculate(arg); - } else if (auto arg = operands[0].dyn_cast_or_null()) { - if (auto elementAttr = arg.getSplatValue().dyn_cast()) - return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); - } - return {}; -} - -//===----------------------------------------------------------------------===// -// LogOp -//===----------------------------------------------------------------------===// - -void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ReadVariableOp -//===----------------------------------------------------------------------===// - -void ReadVariableOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// VarIsInitializedOp -//===----------------------------------------------------------------------===// - -namespace { - -/// Erase VarIsInitializedOp operations with no uses. This op has side effect on -/// resources (read-only), but can still be deleted if it has zero uses. -struct EraseDeadVarIsInitializedOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(VarIsInitializedOp op, - PatternRewriter &rewriter) const override { - if (!op.use_empty()) return failure(); - rewriter.eraseOp(op); - return success(); - } -}; -} // end anonymous namespace. 
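The LeakyRelu folder above has two fast paths: alpha == 1 folds to the operand itself, and a splat constant operand folds element-wise at compile time. A standalone sketch of the same arithmetic on plain vectors; FoldLeakyRelu is a hypothetical helper, not the MLIR folder.

#include <cassert>
#include <vector>

// leaky_relu(x, alpha) = x for x >= 0, alpha * x otherwise.
static float LeakyRelu(float x, float alpha) {
  return x < 0.0f ? alpha * x : x;
}

// Folds a constant input entirely ahead of time; a plain-vector stand-in for
// folding a dense constant attribute.
static std::vector<float> FoldLeakyRelu(const std::vector<float> &values,
                                        float alpha) {
  std::vector<float> out;
  out.reserve(values.size());
  for (float v : values) out.push_back(LeakyRelu(v, alpha));
  return out;
}

int main() {
  // alpha == 1: the op is the identity, so folding returns the input as-is.
  std::vector<float> x = {-2.0f, 0.5f};
  assert(FoldLeakyRelu(x, 1.0f) == x);

  // Splat-like constant: every element gets the same treatment.
  std::vector<float> folded = FoldLeakyRelu({-4.0f, -4.0f}, 0.25f);
  assert(folded[0] == -1.0f && folded[1] == -1.0f);
  return 0;
}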
- -void VarIsInitializedOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); -} - -//===----------------------------------------------------------------------===// -// LogicalNotOp -//===----------------------------------------------------------------------===// - -void LogicalNotOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// MatrixBandPartOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(MatrixBandPartOp op) { - if (!HasRankAtLeast(op.input(), 2)) { - return op.emitOpError() - << "requires `input` to have rank of at least 2, but found " - << op.input().getType(); - } - if (!IsOfRankOrUnranked(op.num_lower(), 0)) { - return op.emitOpError() - << "requires `num_lower` to have 0 dimensions, but found " - << op.num_lower().getType(); - } - if (!IsOfRankOrUnranked(op.num_upper(), 0)) { - return op.emitOpError() - << "requires `num_upper` to have 0 dimensions, but found " - << op.num_upper().getType(); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// MaxOp -//===----------------------------------------------------------------------===// - -void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, - Value reduction_indices, BoolAttr keep_dims) { - Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, &builder); - build(builder, result, out_ty, input, reduction_indices, keep_dims); -} - -//===----------------------------------------------------------------------===// -// MaxPoolOp -//===----------------------------------------------------------------------===// - -LogicalResult MaxPoolOp::FoldOperandsPermutation( - ArrayRef permutation) { - return ::mlir::TF::FoldOperandsPermutation( - permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); -} - -//===----------------------------------------------------------------------===// -// MaxPoolGradOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(MaxPoolGradOp op) { - if (!IsOfRankOrUnranked(op.orig_input(), 4)) { - return op.emitOpError() << "requires orig_input to be rank 4"; - } - if (!IsOfRankOrUnranked(op.orig_output(), 4)) { - return op.emitOpError() << "requires orig_output to be rank 4"; - } - if (!IsOfRankOrUnranked(op.grad(), 4)) { - return op.emitOpError() << "requires grad to be rank 4"; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// MeanOp -//===----------------------------------------------------------------------===// - -LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { - // Reduction indices must be defined by a constant operation. - auto reduction_op = - dyn_cast_or_null(reduction_indices().getDefiningOp()); - if (!reduction_op) return failure(); - - auto reductions_value = reduction_op.value().dyn_cast(); - if (!reductions_value) return failure(); - - // Prepare new reduction indices according to operand permutation. - SmallVector shuffled_reduction; - llvm::transform(reductions_value.getIntValues(), - std::back_inserter(shuffled_reduction), - [&](APInt idx) { return permutation[idx.getSExtValue()]; }); - - // Add constant operation with a new reduction indices. 
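The MeanOp permutation folding above remaps each constant reduction index through the layout permutation before rebuilding the constant operand. A standalone sketch of that remapping, assuming permutation[i] gives the new position of dimension i; PermuteReductionIndices is a hypothetical helper.

#include <cassert>
#include <cstdint>
#include <vector>

// Remaps constant reduction indices through a layout permutation.
static std::vector<int32_t> PermuteReductionIndices(
    const std::vector<int32_t> &indices,
    const std::vector<int64_t> &permutation) {
  std::vector<int32_t> result;
  result.reserve(indices.size());
  for (int32_t idx : indices)
    result.push_back(static_cast<int32_t>(permutation[idx]));
  return result;
}

int main() {
  // NHWC -> NCHW style permutation: dim 0->0, 1->2, 2->3, 3->1.
  std::vector<int64_t> nhwc_to_nchw = {0, 2, 3, 1};
  // Reducing over the spatial dims (1, 2) in NHWC becomes (2, 3) in NCHW.
  std::vector<int32_t> shuffled = PermuteReductionIndices({1, 2}, nhwc_to_nchw);
  assert(shuffled[0] == 2 && shuffled[1] == 3);
  return 0;
}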
- OpBuilder builder(getOperation()); - auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), - builder.getIntegerType(32)); - auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); - auto shuffled_reduction_op = builder.create(getLoc(), values); - - // Use new reduction indices. - setOperand(1, shuffled_reduction_op); - - return success(); -} - -//===----------------------------------------------------------------------===// -// MulOp -//===----------------------------------------------------------------------===// - -OpFoldResult MulOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// NegOp -//===----------------------------------------------------------------------===// - -void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// NotEqualOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(NotEqualOp op) { - // If we allow inputs to have incompatible type, then nothing to do. - if (!op.incompatible_shape_error()) return success(); - - // Otherwise, check inputs are broadcastable. - return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( - op.getOperation()); -} - -void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, - Value y, BoolAttr incompatible_shape_error) { - auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, - incompatible_shape_error); - return build(builder, result, result_type, x, y, incompatible_shape_error); -} - -//===----------------------------------------------------------------------===// -// OneHotOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(OneHotOp op) { - int64_t axis = op.axis().getSExtValue(); - - auto indices_ty = op.indices().getType().dyn_cast(); - if (indices_ty && - !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { - return op.emitOpError() - << "expected axis (" << axis << ") to be -1 or between [0, " - << indices_ty.getShape().size() << "]"; - } - - if (axis < -1) { - return op.emitOpError() << "expected axis (" << axis - << ") to be -1 or between [0, rank(indices()))"; - } - - if (!IsOfRankOrUnranked(op.depth(), 0)) { - return op.emitOpError() << "requires depth to be a scalar"; - } - if (!IsOfRankOrUnranked(op.on_value(), 0)) { - return op.emitOpError() << "requires on_value to be a scalar"; - } - if (!IsOfRankOrUnranked(op.off_value(), 0)) { - return op.emitOpError() << "requires off_value to be a scalar"; - } - - DenseIntElementsAttr depth_attr; - if (matchPattern(op.depth(), m_Constant(&depth_attr))) { - if (depth_attr.getType().getRank() != 0) - return op.emitOpError() << "requires depth to be a scalar"; - int64_t depth = depth_attr.getValue({}).getSExtValue(); - if (depth < 0) { - return op.emitOpError() << "depth must be non-negative, got: " << depth; - } - } - - return success(); -} - -static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, - Value off_value, IntegerAttr axis) { - int64_t axis_val = axis.getInt(); - Type element_ty = on_value.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - if (axis_val < -1) return unranked_ty; - - auto indices_ty = 
indices.getType().dyn_cast(); - if (!indices_ty) return unranked_ty; - - auto shape = llvm::to_vector<2>(indices_ty.getShape()); - if (axis_val == -1) axis_val = shape.size(); - - int64_t depth_val = ShapedType::kDynamicSize; - DenseIntElementsAttr depth_attr; - if (matchPattern(depth, m_Constant(&depth_attr)) && - depth_attr.getNumElements() == 1) - depth_val = (*depth_attr.begin()).getSExtValue(); - shape.insert(shape.begin() + axis_val, depth_val); - return RankedTensorType::get(shape, element_ty); -} - -void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, - Value depth, Value on_value, Value off_value, - IntegerAttr axis) { - build(builder, result, - InferOneHotOpType(indices, depth, on_value, off_value, axis), indices, - depth, on_value, off_value, axis); -} - -//===----------------------------------------------------------------------===// -// PackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(PackOp op) { - // TODO(hinsu): Convert variadic length attributes to derived attributes. - Operation::operand_range values = op.values(); - - if (failed(VerifyTypesCompatibility(values, - /*mask_one_dim=*/false, - op.getOperation()))) { - return failure(); - } - - int64_t inputs_rank = -1; - for (Value value : values) { - if (auto ty = value.getType().dyn_cast()) { - // Exit early as input types are verified to be compatible so all ranked - // tensors have the same rank. - inputs_rank = ty.getRank(); - break; - } - } - if (inputs_rank == -1) return success(); - - // The values can be packed along any of the dimensions between 0 and - // inputs rank, inclusive. Also, as the negative axis values wrap around so - // the axis value range is [-(R+1), R+1). - int64_t range_begin = -inputs_rank - 1; // Inclusive - int64_t range_end = inputs_rank + 1; // Exclusive - int64_t axis = op.axis().getSExtValue(); - if (axis < range_begin || axis >= range_end) { - return op.emitError() << "attribute 'axis' should be within range [" - << range_begin << ", " << range_end - << "); actual value: " << axis; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PadOp -//===----------------------------------------------------------------------===// - -LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { - // Paddings must be defined by a constant operation. - auto paddings_op = dyn_cast_or_null(paddings().getDefiningOp()); - if (!paddings_op) return failure(); - - auto paddings_value = paddings_op.value().dyn_cast(); - if (!paddings_value || - paddings_value.getNumElements() != permutation.size() * 2) - return failure(); - - SmallVector shuffled_paddings(paddings_value.getNumElements()); - for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) { - size_t outer_idx = index_pair.index() / 2; - size_t inner_idx = index_pair.index() % 2; - - shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] = - index_pair.value().getSExtValue(); - } - - // Add constant operation with a new paddings. - OpBuilder builder(getOperation()); - auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(), - builder.getIntegerType(32)); - auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings); - auto shuffled_paddings_op = builder.create(getLoc(), values); - - // Use new paddings. - setOperand(1, shuffled_paddings_op); - - // Change the result type. 
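The one-hot type inference earlier in this hunk inserts the depth value into the indices shape at the requested axis, with axis == -1 meaning the trailing position. A simplified standalone sketch; InferOneHotShape is hypothetical and skips the axis < -1 unranked case and dynamic depth handling.

#include <cassert>
#include <cstdint>
#include <vector>

// Output shape = indices shape with `depth` inserted at position `axis`;
// axis == -1 appends at the end.
static std::vector<int64_t> InferOneHotShape(std::vector<int64_t> indices_shape,
                                             int64_t depth, int64_t axis) {
  if (axis == -1) axis = static_cast<int64_t>(indices_shape.size());
  indices_shape.insert(indices_shape.begin() + axis, depth);
  return indices_shape;
}

int main() {
  // indices of shape [2, 3], depth 10, axis -1 -> [2, 3, 10].
  std::vector<int64_t> trailing = InferOneHotShape({2, 3}, 10, -1);
  assert((trailing == std::vector<int64_t>{2, 3, 10}));

  // Same indices with axis 0 -> [10, 2, 3].
  std::vector<int64_t> leading = InferOneHotShape({2, 3}, 10, 0);
  assert((leading == std::vector<int64_t>{10, 2, 3}));
  return 0;
}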
- getResult().setType(ShuffleRankedTensorType(getResult().getType(), - ReversePermutation(permutation))); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ParseExampleV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ParseExampleV2Op op) { - // NOTE(mrry): This validates properties of an op that would previously be - // validated by the TensorFlow OpDef type checker. In addition to these - // checks, the shape inference function for ParseExampleV2 validates the - // consistency of the argument and result types. - - // Validate dense variadic input and output lengths. - // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we - // do not need to validate dense_defaults. - auto dense_types_count = - std::distance(op.Tdense().begin(), op.Tdense().end()); - auto dense_values_count = - std::distance(op.dense_values().begin(), op.dense_values().end()); - if (dense_values_count != dense_types_count) { - return op.emitError() << "output 'dense_values' should have same length " - << "as attribute 'Tdense'"; - } - - // Validate sparse variadic output lengths. - // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we - // do not need to validate sparse_values. - auto sparse_types_count = - std::distance(op.sparse_types().begin(), op.sparse_types().end()); - if (op.num_sparse() != sparse_types_count) { - return op.emitError() << "attribute 'num_sparse' should be the same as " - << "the length of attribute 'sparse_types'"; - } - if (op.sparse_indices().size() != sparse_types_count) { - return op.emitError() << "output 'sparse_indices' should have same length " - << "as attribute 'sparse_types'"; - } - if (op.sparse_shapes().size() != sparse_types_count) { - return op.emitError() << "output 'sparse_shapes' should have same length " - << "as attribute 'sparse_types'"; - } - - // Validate ragged variadic output lengths. 
- auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), - op.ragged_value_types().end()); - auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), - op.ragged_split_types().end()); - if (ragged_value_types_count != ragged_split_types_count) { - return op.emitError() << "attribute 'ragged_value_types' should have same " - << "length as attribute 'ragged_split_types'"; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PartitionedCallOp -//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyPartitionedCall(OpClass op) { - auto module = op.template getParentOfType(); - SymbolRefAttr func = op.getAttr("f").template cast(); - - auto function = - dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); - - if (!function) { - return op.emitError("'f' attribute refers to an undefined function: ") - << func; - } - - FunctionType function_ty = function.getType(); - int func_arg_count = function_ty.getNumInputs(); - int arg_count = op.args().size(); - - if (arg_count != func_arg_count) { - return op.emitError() << "argument count mismatch: 'args' has " << arg_count - << " arguments, but '" << func << "' expects " - << func_arg_count; - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// PowOp -//===----------------------------------------------------------------------===// - -OpFoldResult PowOp::fold(ArrayRef operands) { - auto constant_y = operands[1].dyn_cast_or_null(); - if (constant_y && constant_y.isSplat()) { - APFloat y_value = constant_y.getSplatValue(); - auto output_type = getType().cast(); - if (y_value.isZero() && output_type.hasStaticShape()) { - return DenseElementsAttr::get( - output_type, - FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); - } - if (y_value.isExactlyValue(1.0)) { - return x(); - } - } - return {}; -} - -//===----------------------------------------------------------------------===// -// QrOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input type, if ranked, must have at least 2 dimensions and at most -// INT32_MAX dimensions. 
-// -static LogicalResult Verify(QrOp op) { - auto ttype = op.input().getType().cast(); - if (!ttype.hasRank()) return success(); - if (!HasRankAtLeast(op.input(), 2)) - return op.emitOpError( - "requires ranked input tensor to be of rank 2 or more"); - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) - return op.emitOpError( - "requires ranked input tensor to be of rank INT32_MAX or less"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ReciprocalOp -//===----------------------------------------------------------------------===// - -void ReciprocalOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// RandomUniformOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(RandomUniformOp op) { - if (!IsOfRankOrUnranked(op.shape(), 1)) - return op.emitOpError("shape must be 1D tensor"); - return success(); -} - -//===----------------------------------------------------------------------===// -// RangeOp -//===----------------------------------------------------------------------===// - -void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, - Value limit, Value delta) { - assert(start.getType() == limit.getType()); - assert(start.getType() == delta.getType()); - DenseIntElementsAttr start_val; - DenseIntElementsAttr limit_val; - DenseIntElementsAttr delta_val; - if (matchPattern(start, m_Constant(&start_val)) && - matchPattern(limit, m_Constant(&limit_val)) && - matchPattern(delta, m_Constant(&delta_val))) { - auto size = llvm::APIntOps::RoundingSDiv( - *limit_val.begin() - *start_val.begin(), *delta_val.begin(), - llvm::APInt::Rounding::DOWN); - return RangeOp::build( - builder, result, - RankedTensorType::get( - size.getSExtValue(), - start.getType().cast().getElementType()), - start, limit, delta); - } - return RangeOp::build( - builder, result, - RankedTensorType::get( - {-1}, start.getType().cast().getElementType()), - start, limit, delta); -} -//===----------------------------------------------------------------------===// -// RankOp -//===----------------------------------------------------------------------===// - -void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { - return RankOp::build(builder, result, - RankedTensorType::get({}, builder.getIntegerType(32)), - input); -} - -// This will create a constant value for RankOp of a ranked tensor. 
-OpFoldResult RankOp::fold(ArrayRef operands) { - auto type = input().getType(); - auto ranked_type = type.dyn_cast(); - if (!ranked_type) return {}; - - auto output_type = getType().cast(); - int32_t rank = ranked_type.getRank(); - return DenseIntElementsAttr::get(output_type, rank); -} - -//===----------------------------------------------------------------------===// -// RealDivOp -//===----------------------------------------------------------------------===// - -void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult RealDivOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// ReshapeOp -//===----------------------------------------------------------------------===// - -// TODO(b/128020684): Verify the output type. -static LogicalResult Verify(ReshapeOp op) { - auto shape_type = op.shape().getType().cast(); - if (!shape_type.hasRank()) return success(); - if (shape_type.getRank() != 1) - return op.emitOpError("shape must be 1D tensor"); - auto rank_by_shape = shape_type.getShape()[0]; - auto type_of_tensor = op.tensor().getType().cast(); - // No compile time verification for unknown sized shape. - if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); - int64_t num_by_tensor = type_of_tensor.getNumElements(); - - auto out_ty = op.getType().dyn_cast(); - if (out_ty && out_ty.hasStaticShape()) { - int64_t num_output_elements = out_ty.getNumElements(); - if (num_by_tensor != num_output_elements) - return op.emitOpError() - << "number of output elements (" << num_output_elements - << ") does not match expected number of elements (" - << num_by_tensor << ")"; - } - - // Check values if constant shape. No compiling time verification for - // non-constant shape. - auto *shape_op = op.shape().getDefiningOp(); - if (!shape_op) return success(); - Attribute shape_cst; - if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success(); - auto shape_cst_attr = shape_cst.dyn_cast(); - if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor"); - - if (auto opaque_attr = shape_cst_attr.dyn_cast()) { - opaque_attr.decode(shape_cst_attr); - } - - // We know the shape is a 1-D Tensor, then let us get the number of - // elements it implies. - unsigned num_by_shape = 1; - unsigned unknown_dim_count = 0; - for (int i = 0, e = rank_by_shape; i != e; ++i) { - auto num = shape_cst_attr.getValue(i).getInt(); - // The dimension size value can be -1, and that the real size needs to - // be computed so that the total size remains constant. At most one - // component of shape can be -1. - if (num == -1) { - if (++unknown_dim_count > 1) { - return op.emitOpError("more than one component of shape are -1"); - } - } else { - num_by_shape *= num; - } - } - // If there is one component of shape is -1, the dimension should be - // computed so that the total size remains constant. - if (unknown_dim_count == 1) { - if (num_by_tensor % num_by_shape != 0) - return op.emitOpError( - "one component of shape is -1 but couldn't infer the dimension"); - return success(); - } - // If the elements by the tensor and implies by the shape don't match, - // fail this static check. 
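The element-count reasoning above is the usual reshape rule: at most one target dimension may be -1, and that entry is resolved so the total element count is preserved. A standalone sketch of the rule; ResolveReshape is a hypothetical helper, not the verifier itself.

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Returns the resolved shape, or nullopt when the request is invalid
// (more than one -1, or element counts that cannot match).
static std::optional<std::vector<int64_t>> ResolveReshape(
    int64_t input_elements, std::vector<int64_t> shape) {
  int64_t known_product = 1;
  int unknown_index = -1;
  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
    if (shape[i] == -1) {
      if (unknown_index != -1) return std::nullopt;  // more than one -1
      unknown_index = i;
    } else {
      known_product *= shape[i];
    }
  }
  if (unknown_index == -1) {
    if (input_elements != known_product) return std::nullopt;
    return shape;
  }
  if (known_product == 0 || input_elements % known_product != 0)
    return std::nullopt;
  shape[unknown_index] = input_elements / known_product;
  return shape;
}

int main() {
  // 24 elements reshaped to [2, -1, 3] resolves the -1 to 4.
  auto ok = ResolveReshape(24, {2, -1, 3});
  assert(ok && (*ok == std::vector<int64_t>{2, 4, 3}));
  // 24 elements cannot be reshaped to [5, -1]: 24 % 5 != 0.
  assert(!ResolveReshape(24, {5, -1}));
  return 0;
}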
- if (num_by_tensor != num_by_shape) { - return op.emitOpError( - "mismatch in tensor elements and shape implied elements"); - } - return success(); -} - -void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, - Value shape) { - auto ttype = tensor.getType().cast(); - auto etype = ttype.getElementType(); - - auto unranked = [&builder, etype, &result, shape, tensor]() { - return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), - tensor, shape); - }; - - // If tensor is unranked then we have no info about output of shape. - if (!ttype.hasRank()) return unranked(); - - DenseIntElementsAttr attr_shape; - if (matchPattern(shape, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - const_shape.reserve(attr_shape.getNumElements()); - - // Detect if reshape output shape is folded. - bool flatten = false; - int unknown_index = -1; - // The product of constant shape argument excluding unknown dimension. - int64_t product_cshape = 1; - for (auto e : llvm::enumerate(attr_shape)) { - int64_t val = e.value().getSExtValue(); - if (IsUnknownDimOrRank(val)) { - if (flatten) { - mlir::emitError(result.location) - << "only one unknown dimension allowed"; - return; - } - flatten = true; - unknown_index = e.index(); - } else { - product_cshape *= val; - } - const_shape.push_back(val); - } - - // Compute the value of the unknown dimension. - if (flatten) { - // Compute number of elements in tensor shape. - auto tshape = ttype.getShape(); - int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, - std::multiplies()); - // Set the unknown dimension such that total number of elements remain - // constant. - // Note: The case where the ratio is not integral, and so the total size - // of reshape not constant, is checked in verify function. - const_shape[unknown_index] = product_tshape / product_cshape; - } - return ReshapeOp::build(builder, result, - RankedTensorType::get(const_shape, etype), tensor, - shape); - } - return unranked(); -} - -void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// SelectOp -//===----------------------------------------------------------------------===// - -void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -// Verifies a few extra requirements on SelectOp: -// (1) `then` and `else` must have same shape -// (2) At least one of the following must be true: -// (a) `cond` has the same rank as `then` and `else` -// (b) `cond` is a scalar -// (c) `cond` is a vector AND `then` and `else` are non-scalar with their -// first dimension equal to `cond`. -static LogicalResult Verify(SelectOp op) { - auto then_tensor = op.t().getType().cast(); - auto else_tensor = op.e().getType().cast(); - // Check (1). - if (!AreCastCompatible({then_tensor, else_tensor})) - return op.emitOpError() << "requires t and e have compatible shapes"; - - // Get data rank (if exists). - int data_rank; - // If data is unranked or data_rank is 0, this will remain -2. Otherwise - // refers to first dimension of then and/or else. 
- int data_first_dim = -2; - bool then_has_rank = then_tensor.hasRank(); - bool else_has_rank = else_tensor.hasRank(); - if (then_has_rank && else_has_rank) { - data_rank = then_tensor.getRank(); - if (then_tensor.getRank() > 0) - data_first_dim = then_tensor.getShape().front(); - if (else_tensor.getRank() > 0) - data_first_dim = std::max( - static_cast(else_tensor.getShape().front()), data_first_dim); - } else if (then_has_rank) { - data_rank = then_tensor.getRank(); - if (then_tensor.getRank() > 0) - data_first_dim = then_tensor.getShape().front(); - } else if (else_has_rank) { - data_rank = else_tensor.getRank(); - if (else_tensor.getRank() > 0) - data_first_dim = else_tensor.getShape().front(); - } else { - // Neither has a rank. - return success(); - } - - auto cond_tensor = op.condition().getType().dyn_cast(); - if (!cond_tensor) return success(); - auto cond_rank = cond_tensor.getRank(); - // Check (2a) and (2b). - if (cond_rank == 0 || cond_rank == data_rank) return success(); - // Check (2c). - if (cond_rank == 1) { - auto cond_shape = cond_tensor.getShape().front(); - if (data_rank == 0) { - return op.emitOpError() - << "requires that t and e are nonscalar when pred is a vector"; - } - // We know `data` tensor has a rank of at least 1. - if (data_first_dim != -1 && cond_shape != -1 && - data_first_dim != cond_shape) { - return op.emitOpError() << "requires that, when pred is a vector, the " - "shape matches the first dimension of t and e"; - } - return success(); - } - // None of (2a,b,c) were true; fail. - return op.emitOpError() << "requires that pred is a scalar OR has the same " - "rank as t and e OR is a vector"; -} - -//===----------------------------------------------------------------------===// -// SelectV2Op -//===----------------------------------------------------------------------===// - -static Type InferSelectV2OpType(Value condition, Value e, Value t) { - Type element_ty = e.getType().cast().getElementType(); - auto unranked_ty = UnrankedTensorType::get(element_ty); - - Type broadcasted_ty = - OpTrait::util::getBroadcastedType(e.getType(), t.getType()); - if (!broadcasted_ty) return unranked_ty; - - auto cond_ranked_ty = condition.getType().dyn_cast(); - auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); - if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; - - // Explicitly get broadcasted output type as element types of condition may - // not be same as the broadcated type's element type. - SmallVector result_shape; - if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(), - broadcasted_ranked_ty.getShape(), - result_shape)) - return unranked_ty; - return RankedTensorType::get(result_shape, element_ty); -} - -void SelectV2Op::build(OpBuilder &builder, OperationState &result, - Value condition, Value e, Value t) { - build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t); -} - -//===----------------------------------------------------------------------===// -// ShapeOp -//===----------------------------------------------------------------------===// - -namespace { -// Validates Shape/ShapeN/VariableShape operand and associated result types. -LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, - Type result_type, - int variadic_idx = -1) { - std::string variadic_idx_str = - variadic_idx < 0 ? 
"" : llvm::formatv(" #{0}", variadic_idx).str(); - - auto result_ranked_type = result_type.dyn_cast(); - if (!result_ranked_type) return success(); - if (result_ranked_type.getShape().size() != 1) - return op->emitOpError("requires 1D type for result") << variadic_idx_str; - - auto operand_ranked_type = operand_type.dyn_cast_or_null(); - if (operand_ranked_type) { - // The operand is a ranked tensor. - if (result_ranked_type.hasStaticShape() && - !operand_ranked_type.getShape().empty() && - result_ranked_type.getDimSize(0) != - operand_ranked_type.getShape().size()) - return op->emitOpError("requires dimension size of result") - << variadic_idx_str << " to match rank of operand" - << variadic_idx_str; - } else if (result_ranked_type.hasStaticShape()) { - // The operand is an unranked tensor, print a warning if the result - // is static. - // Note: We do not handle this situation as an error, this would be too - // restrictive due to incompleteness of shape inference at this point. - op->emitWarning("has static shape result") - << variadic_idx_str << " for unranked operand" << variadic_idx_str; - } - - Type element_type = result_ranked_type.getElementType(); - if (!element_type.isSignlessInteger(32) && - !element_type.isSignlessInteger(64)) - return op->emitOpError("requires int32 or int64 return type for result") - << variadic_idx_str; - - return success(); -} -} // anonymous namespace - -static LogicalResult Verify(ShapeOp op) { - return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType()); -} - -// Converts shape of the given type to attribute if it is of ranked tensor type. -// Returned attribute has integer elements of the given width. -static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { - auto ranked_ty = input_ty.dyn_cast(); - if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; - - auto shape = ranked_ty.getShape(); - int rank = shape.size(); - - SmallVector dimensions; - dimensions.reserve(rank); - for (int i = 0; i < rank; ++i) - dimensions.push_back(APInt(out_width, shape[i])); - - auto result_type = RankedTensorType::get( - {rank}, IntegerType::get(out_width, input_ty.getContext())); - return DenseElementsAttr::get(result_type, dimensions); -} - -OpFoldResult ShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - return ConvertShapeToAttr(getOperand().getType(), width); -} - -void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, - BoolAttr use32Bit) { - auto rankedTensorType = input.getType().dyn_cast(); - int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; - auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) - : builder.getIntegerType(64); - return ShapeOp::build(builder, result, - RankedTensorType::get({rank}, out_type), input); -} - -//===----------------------------------------------------------------------===// -// ShapeNOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(ShapeNOp op) { - const size_t num_tensors = op.N(); - - if (op.getNumOperands() != num_tensors) - return op.emitOpError() << "requires " << num_tensors << " operand(s), got " - << op.getNumOperands() << " operand(s)"; - - if (op.getNumResults() != num_tensors) - return op.emitOpError() << "requires " << num_tensors << " result(s), got " - << op.getNumResults() << " result(s)"; - - for (auto i : llvm::seq(0, num_tensors)) { - auto verification = VerifyShapeOperandAndResult( - op, op.getOperand(i).getType(), op.getResult(i).getType(), i); - if (failed(verification)) return verification; - } - - return success(); -} - -LogicalResult ShapeNOp::fold(ArrayRef operands, - SmallVectorImpl &results) { - if (getNumOperands() == 0) return success(); - int width = - getType(0).cast().getElementType().getIntOrFloatBitWidth(); - - for (Type input_ty : getOperandTypes()) { - OpFoldResult result = ConvertShapeToAttr(input_ty, width); - if (!result) return failure(); - - results.push_back(result); - } - return success(); -} - -// TODO(hinsu): Add canonicalization pattern for ShapeN ops that don't have all -// static input shapes. Replacing output values corresponding to static input -// types may enable optimizations in users of the values. - -//===----------------------------------------------------------------------===// -// SizeOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input type, if is a ranked tensor, has at most INT32_MAX dimensions. -// -static LogicalResult Verify(SizeOp op) { - if (!HasRankAtMost(op.input(), std::numeric_limits::max())) - return op.emitOpError( - "requires ranked input tensor to be of rank INT32_MAX or less"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// SliceOp -//===----------------------------------------------------------------------===// - -// Verifies that: -// -// - operands begin and size are 1D with the same number of elements. -// - if the input is a ranked tensor, the rank of the input equals the number -// of elements in operands begin and size. 
-// - if begin are constants, that -// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] -// - if begins aren't constant but the input is a ranked tensor, that -// size[i] <= input_ty.getShape()[i] -// -static LogicalResult Verify(SliceOp op) { - RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); - if (begin_ty && begin_ty.getRank() != 1) { - return op.emitOpError() << "requires begin operand to be 1D tensor"; - } - - RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size()); - if (size_ty && size_ty.getRank() != 1) { - return op.emitOpError() << "requires size operand to be 1D tensor"; - } - - if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() || - !size_ty.hasStaticShape()) - return success(); - - if (begin_ty.getNumElements() != size_ty.getNumElements()) { - return op.emitOpError() << "requires begin and size operands to have the" - " same number of elements"; - } - - auto input_ty = op.input().getType().dyn_cast(); - if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { - return op.emitOpError() << "requires number of elements in begin and size" - "are equal to input rank"; - } - - DenseIntElementsAttr begin_indices; - if (matchPattern(op.begin(), m_Constant(&begin_indices))) { - DenseIntElementsAttr slice_sizes; - bool constant_slice_sizes = - matchPattern(op.size(), m_Constant(&slice_sizes)); - int dim = 0; - for (const APInt &raw_begin_index : begin_indices.getValues()) { - int64_t begin_index = raw_begin_index.getSExtValue(); - int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; - int64_t slice_size = constant_slice_sizes - ? slice_sizes.getValue(dim).getSExtValue() - : 0; - if (slice_size == -1 && input_size != -1) { - slice_size = input_size - begin_index; - } - if (begin_index < 0 || - (input_size != -1 && begin_index + slice_size > input_size)) { - return op.emitOpError() - << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di"; - } - ++dim; - } - } else if (input_ty) { - // If the inputs are ranked, we can do a few more sanity checks. - DenseIntElementsAttr slice_sizes; - if (matchPattern(op.size(), m_Constant(&slice_sizes))) { - auto input_shape = input_ty.getShape(); - for (int64_t i = 0; i < input_ty.getRank(); ++i) { - int64_t slice_size = slice_sizes.getValue(i).getInt(); - int64_t input_size = input_shape[i]; - if (slice_size != -1 && input_size != -1 && slice_size > input_size) { - return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " - "is unknown at compile time"; - } - } - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// SoftmaxOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SoftmaxOp op) { - if (!HasRankAtLeast(op.logits(), 1)) { - return op.emitOpError("requires operand to have rank at least 1"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// SoftmaxCrossEntropyWithLogitsOp -//===----------------------------------------------------------------------===// - -// Verifies that, -// -// * Input types are broadcast compatible and the broadcasted type has rank two. 
-// -static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { - auto broadcasted_ty = OpTrait::util::getBroadcastedType( - op.features().getType(), op.labels().getType()) - .dyn_cast_or_null(); - if (!broadcasted_ty || - (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) - return op.emitOpError( - "requires features and labels to be broadcast compatible to rank two"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// SparseSoftmaxCrossEntropyWithLogitsOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) { - if (!IsOfRankOrUnranked(op.features(), 2)) { - return op.emitOpError("requires features operand of rank two"); - } - if (!IsOfRankOrUnranked(op.labels(), 1)) { - return op.emitOpError("requires labels operand of rank one"); - } - auto features_ty = op.features().getType().dyn_cast(); - auto labels_ty = op.labels().getType().dyn_cast(); - if (features_ty && labels_ty) { - int64_t features_batches = features_ty.getDimSize(0); - int64_t labels_batches = labels_ty.getDimSize(0); - if (!ShapedType::isDynamic(features_batches) && - !ShapedType::isDynamic(labels_batches) && - features_batches != labels_batches) - return op.emitOpError( - "requires features and labels with matching first dimension"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// SplitOp -//===----------------------------------------------------------------------===// - -// Verifies the input and split dimension operands for tf.Split/tf.SplitV. -// Writes the split dimension's index (adjusted with input rank) via `dim_index` -// if it's a constant. -template -LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { - *dim_index = llvm::None; - - Value split_dim = op.split_dim(); - if (auto split_dim_type = split_dim.getType().dyn_cast()) - if (split_dim_type.getRank() != 0) - return op.emitOpError( - "split dimension should be an integer scalar tensor"); - - // We can perform further verification if the input tensor to be split has - // known rank and the split dimension tensor is a constant. 
- - auto input_type = op.value().getType().template dyn_cast(); - if (!input_type) return success(); - - int64_t input_rank = input_type.getRank(); - if (input_rank == 0) - return op.emitOpError("cannot split scalar input tensor"); - - DenseIntElementsAttr split_dim_attr; - if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success(); - - int64_t index = (*split_dim_attr.begin()).getSExtValue(); - - if (index + input_rank < 0 || index >= input_rank) { - return op.emitOpError("split dimension must be in range [-") - << input_rank << ", " << input_rank << ")"; - } - - if (index < 0) index += input_rank; - *dim_index = index; - - return success(); -} - -static LogicalResult Verify(SplitOp op) { - Optional dim_index; - if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); - if (!dim_index) return success(); - - int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); - if (input_dim_size == ShapedType::kDynamicSize) return success(); - - if (input_dim_size % op.getNumResults() != 0) - return op.emitOpError("dimension #") - << *dim_index << " not divisible by the number of result tensors"; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SplitVOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(SplitVOp op) { - auto split_sizes_type = - op.size_splits().getType().dyn_cast(); - if (!split_sizes_type) return success(); - - if (split_sizes_type.getRank() != 1 || - split_sizes_type.getDimSize(0) != op.getNumResults()) - return op.emitOpError("split sizes should be a 1D tensor of ") - << op.getNumResults() << " elements"; - - Optional dim_index = 0; - if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); - if (!dim_index) return success(); - - int64_t input_dim_size = - op.value().getType().cast().getDimSize(*dim_index); - if (input_dim_size == ShapedType::kDynamicSize) return success(); - - // If split sizes come from a constant, they must sum to the dimension size - // along split_dim, and we can have no more than one dynamic dimension. - DenseIntElementsAttr split_sizes_attr; - if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) - return success(); - - int64_t total_dim_size = 0; // Total dimension size assigned to splits - llvm::Optional dynamic_dim_index; - - SmallVector split_sizes; - split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); - - for (auto dim : llvm::enumerate(split_sizes_attr)) { - int64_t dim_val = dim.value().getSExtValue(); - split_sizes.push_back(dim_val); - if (dim_val == ShapedType::kDynamicSize) { - // We cannot have more than one dynamic dimension. 
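The SplitV size check around this point enforces the usual rule for constant split sizes: at most one entry may be dynamic, and the static entries must sum to the split dimension exactly (or to at most the dimension when a dynamic entry absorbs the remainder). A standalone sketch, using -1 in place of ShapedType::kDynamicSize; ValidSplitSizes is a hypothetical helper.

#include <cassert>
#include <cstdint>
#include <vector>

static bool ValidSplitSizes(const std::vector<int64_t> &sizes,
                            int64_t split_dim_size) {
  int64_t total = 0;
  int dynamic_count = 0;
  for (int64_t s : sizes) {
    if (s == -1) {
      ++dynamic_count;
    } else {
      total += s;
    }
  }
  if (dynamic_count > 1) return false;               // at most one dynamic entry
  if (dynamic_count == 0) return total == split_dim_size;
  return total <= split_dim_size;                    // remainder goes to the -1
}

int main() {
  assert(ValidSplitSizes({2, 3, 5}, 10));    // 2 + 3 + 5 == 10
  assert(ValidSplitSizes({2, -1, 5}, 10));   // -1 resolves to 3
  assert(!ValidSplitSizes({2, -1, -1}, 10)); // two dynamic entries
  assert(!ValidSplitSizes({2, 3}, 10));      // 5 != 10 with no dynamic entry
  return 0;
}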
- if (dynamic_dim_index) - return op.emitOpError( - "cannot have more than one dynamic dimension in split sizes"); - dynamic_dim_index = dim.index(); - } else { - total_dim_size += dim_val; - } - } - - if (!dynamic_dim_index && total_dim_size != input_dim_size) - return op.emitOpError( - "split sizes must sum up to the dimension size along split " - "dimension, found ") - << total_dim_size << " vs " << input_dim_size; - - if (dynamic_dim_index && total_dim_size > input_dim_size) - return op.emitOpError( - "split sizes must sum up to be less than or equal to the " - "dimension size along split dimension, found ") - << total_dim_size << " vs " << input_dim_size; - - return success(); -} - -//===----------------------------------------------------------------------===// -// SquareOp -//===----------------------------------------------------------------------===// - -void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// SubOp -//===----------------------------------------------------------------------===// - -void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -OpFoldResult SubOp::fold(ArrayRef operands) { - return IdentityArithmeticOpFolder(*this, operands); -} - -//===----------------------------------------------------------------------===// -// SumOp -//===----------------------------------------------------------------------===// - -void SumOp::build(OpBuilder &builder, OperationState &result, Value input, - Value reduction_indices, BoolAttr keep_dims) { - Type out_ty = - InferReductionOpType(input, reduction_indices, keep_dims, &builder); - build(builder, result, out_ty, input, reduction_indices, keep_dims); -} - -//===----------------------------------------------------------------------===// -// StridedSliceOp -//===----------------------------------------------------------------------===// - -// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to -// tf.SliceOp if both of the following are true: -// - All strides have a known value equal to 1 -// - No masks are set (or masks can be applied by transforming the inputs to -// Slice) - -// Verifies that, -// -// - begin, end and strides operands are 1D and they have the same number of -// elements. Here, the number of elements should be less than 32 to support -// 32-bit mask attributes. -// - None of the strides values are zero. -// - Ellipsis mask can have at most one bit set. - -template -static LogicalResult VerifyStridedSliceBase(OpTy op) { - // Expected size for operands begin, end and strides vector operands. - int64_t expected_size = -1; - - for (Value val : {op.begin(), op.end(), op.strides()}) { - auto operand_ty = val.getType().dyn_cast(); - if (!operand_ty || !operand_ty.hasStaticShape()) { - // TensorFlow constant ops may have non-static shape because the shape is - // not propagated during constant folding. If the defining op for this - // operand is a constant op, use the constant op's attribute to get the - // actual shape. 
- DenseIntElementsAttr attr; - if (!matchPattern(val, m_Constant(&attr))) continue; - operand_ty = attr.getType(); - } - - if (operand_ty.getRank() != 1) - return op.emitOpError() - << "requires begin, end and strides to be 1D tensors"; - - int64_t length = operand_ty.getDimSize(0); - if (length == -1) continue; - - if (expected_size == -1) { - // This op uses 32-bit masks. - if (length >= 32) - return op.emitOpError( - "requires begin, end and strides operands with less than 32 " - "elements"); - - expected_size = length; - } else if (length != expected_size) { - return op.emitOpError() << "requires begin, end and strides to have the " - "same number of elements"; - } - } - - // If strides are constants, verify that none of the element is zero. - DenseIntElementsAttr strides; - if (matchPattern(op.strides(), m_Constant(&strides))) { - if (llvm::is_contained(strides.getValues(), 0)) - return op.emitOpError("requires non-zero strides"); - } - - // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there - // exists only no more than one ellipsis. - uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); - if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) - return op.emitOpError("cannot have multiple ellipses"); - - return success(); -} - -// Clamps the given `val`: returns `low` if `val` is less than `low`; returns -// `high` if `high` is less than `val`; otherwise returns `val`. -template -constexpr const T &Clamp(const T &val, const T &low, const T &high) { - assert(!(high < low)); - return (val < low) ? low : (high < val) ? high : val; -} - -// Checks if the `index` bit of `val` is set. -template -constexpr bool IsSet(const T &val, unsigned index) { - return (val & (1 << index)) != 0; -} - -// Sets the `index` bit of `val`. -template -constexpr void Set(T &val, unsigned index) { - val |= (1 << index); -} - -// Unset the `index` bit of `val`. -template -constexpr void Unset(T &val, unsigned index) { - val &= ~(1 << index); -} - -// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`. -template -constexpr void CopyBit(const T &src, unsigned src_index, T &dst, - unsigned dst_index) { - if (IsSet(src, src_index)) - Set(dst, dst_index); - else - Unset(dst, dst_index); -} - -// The sparse spec of strided slice does not correspond to the number of -// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, -// 4, 8) would have dims = 2. -struct SparseSliceSpec { - int64_t dims; - int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; - const ArrayRef &begin; - const ArrayRef &end; - const ArrayRef &strides; -}; - -// The dense spec of strided slice is the canonicalized version of sparse spec. -// The number of dimensions of dense spec correspond to the number of dimensions -// in operand tensor. -struct DenseSliceSpec { - int64_t dims; - int32_t begin_mask, end_mask, shrink_axis_mask; - SmallVectorImpl &begin; - SmallVectorImpl &end; - SmallVectorImpl &strides; -}; - -// Make a sparse spec into a dense index spec. -// The sparse spec does not correspond to the number of dimensions -// Make a dense spec that corresponds to the number of dimensions -// -// For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then -// we need to produce the missing begin_mask, end_mask for the first two -// dimensions i.e. foo[:, :, 3:, 2]. 
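To make the example in the comment above concrete: foo[..., 3:, 2] has three sparse entries against a rank-4 input, so the single ellipsis must expand to the two leading dimensions, giving the dense spec foo[:, :, 3:, 2]. A tiny sketch of just that counting step; EllipsisSpan is hypothetical and ignores new_axis entries, which the real dense-spec construction also has to account for.

#include <cassert>
#include <cstdint>

// One sparse entry is the ellipsis itself; every other sparse entry consumes
// exactly one dense dimension, so the ellipsis covers whatever is left.
static int64_t EllipsisSpan(int64_t dense_dims, int64_t sparse_dims) {
  return dense_dims - (sparse_dims - 1);
}

int main() {
  // foo[..., 3:, 2]: three sparse entries against a rank-4 tensor.
  assert(EllipsisSpan(/*dense_dims=*/4, /*sparse_dims=*/3) == 2);
  // foo[0, ...]: two sparse entries against a rank-4 tensor.
  assert(EllipsisSpan(4, 2) == 3);
  return 0;
}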
-static void BuildDenseSliceSpec(const SparseSliceSpec &sparse, - DenseSliceSpec *dense) { - // Build expanded dense begin, end, strides, begin_mask, end_mask, and - // shrink_axis_mask. - dense->begin.resize(dense->dims); - dense->end.resize(dense->dims); - dense->strides.resize(dense->dims); - dense->begin_mask = 0; - dense->end_mask = 0; - dense->shrink_axis_mask = 0; - - // Count number of new_axis after ellipsis. This helps in calculating the - // number of dimensions ellipsis represents in the sparse spec. - bool ellipsis_seen = false; - int num_new_axis_after_ellipsis = 0; - for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { - if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index)) - num_new_axis_after_ellipsis++; - if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true; - } - - int dense_index = 0; - for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { - if (IsSet(sparse.new_axis_mask, sparse_index)) continue; - if (IsSet(sparse.ellipsis_mask, sparse_index)) { - auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) + - 1 + num_new_axis_after_ellipsis, - dense->dims); - // Expand ellipsis into the appropriate dense indices. From current index - // until next_index, all dimensions would have begin and end masks set and - // stride 1, i.e., get all elements in those dimensions. - for (; dense_index < next_index; ++dense_index) { - dense->begin[dense_index] = dense->end[dense_index] = 0; - dense->strides[dense_index] = 1; - Set(dense->begin_mask, dense_index); - Set(dense->end_mask, dense_index); - } - continue; - } - assert(dense_index < dense->dims); - // Copy over the sparse indices to dense indices if ellipsis_mask and - // new_axis_mask are not set. - dense->begin[dense_index] = sparse.begin[sparse_index]; - dense->end[dense_index] = sparse.end[sparse_index]; - dense->strides[dense_index] = sparse.strides[sparse_index]; - CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index); - CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index); - CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask, - dense_index); - dense_index++; - } -} - -// For the given `input_shape`, calculates the sliced shape using the given -// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and -// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If -// `shrink_axis_mask` is not zero, this function will not drop the corresponding -// dimensions in `input_shape`; it will turn them into 1s. At the same time, -// canonicalizes `begin`, `end`, and `strides. The calculation follows -// tf.StridedSlice op semantics. -static void CalculateSlicedShapeFromDenseIndices( - MutableArrayRef input_shape, int32_t begin_mask, int32_t end_mask, - int32_t shrink_axis_mask, MutableArrayRef begin, - MutableArrayRef end, MutableArrayRef stride) { - assert(input_shape.size() <= 32); // Only 32-bit masks are supported. - - // Make sure ranges' ranks are consistent with the input. 
- assert(input_shape.size() == begin.size()); - assert(input_shape.size() == end.size()); - assert(input_shape.size() == stride.size()); - - for (int i = 0, e = input_shape.size(); i < e; ++i) { - if (ShapedType::isDynamic(input_shape[i])) continue; - - int64_t dim_i = input_shape[i]; - int64_t begin_i = begin[i]; - int64_t end_i = end[i]; - int64_t stride_i = stride[i]; - - // [0]: mask for begin, [1]: mask for end - int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)}; - // [0]: bound for begin, [1]: bound for end - int64_t bounds[] = {stride_i > 0 ? 0 : -1, - stride_i > 0 ? dim_i : dim_i - 1}; - - // Canonicalizes the given range `point` (begin/end) according to the - // current dimension. `c` means case: 0 for begin, 1 for end. - auto canonicalize = [&](int64_t point, int c) { - if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1]; - - // Add dim as offset to negative range point. - point = point < 0 ? dim_i + point : point; - return Clamp(point, bounds[0], bounds[1]); - }; - - begin_i = canonicalize(begin_i, 0); - end_i = canonicalize(end_i, 1); - - int64_t interval_len = end_i - begin_i; - int64_t size_i = 0; - // If internal length is zero or has different sign from stride, it's a - // degenerated case: we are slicing nothing. Otherwise, calculate the sliced - // size. - if (interval_len != 0 && (interval_len < 0) == (stride_i < 0)) - size_i = (interval_len / stride_i) + (interval_len % stride_i != 0); - - begin[i] = begin_i; - if (IsSet(shrink_axis_mask, i)) { - // Shrink this dimension. It means we only take the element at begin_i. - input_shape[i] = 1; - end[i] = begin_i + 1; - stride[i] = 1; - } else { - input_shape[i] = size_i; - end[i] = end_i; - stride[i] = stride_i; - } - } -} - -// For the given `input_shape`, calculates the sliced shape using the given -// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`, -// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks. -// Updates the result back to `input_shape`. -static void CalculateSlicedShapeFromSparseIndices( - MutableArrayRef input_shape, ArrayRef sparse_begin, - ArrayRef sparse_end, ArrayRef sparse_strides, - int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, - int32_t new_axis_mask, int32_t shrink_axis_mask, - SmallVectorImpl *begin, SmallVectorImpl *end, - SmallVectorImpl *stride) { - int64_t num_sparse_indices = sparse_begin.size(); - SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask, - ellipsis_mask, new_axis_mask, shrink_axis_mask, - sparse_begin, sparse_end, sparse_strides}; - - // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is - // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields - // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. 
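As an aside before the function body continues below, the per-dimension size arithmetic in CalculateSlicedShapeFromDenseIndices above is easy to miss; here is a standalone sketch of the same formula on assumed, already-canonicalized begin/end/stride values (illustrative names, no MLIR types):

```c++
#include <cstdint>
#include <iostream>

// Sliced size of one dimension, given canonicalized begin, end, and stride.
int64_t SlicedDimSize(int64_t begin, int64_t end, int64_t stride) {
  int64_t interval_len = end - begin;
  // Degenerate slice: empty interval, or interval and stride disagree in sign.
  if (interval_len == 0 || (interval_len < 0) != (stride < 0)) return 0;
  // Round the magnitude up, e.g. begin=2, end=7, stride=2 selects 3 elements.
  return interval_len / stride + (interval_len % stride != 0);
}

int main() {
  std::cout << SlicedDimSize(2, 7, 2) << "\n";   // 3  (indices 2, 4, 6)
  std::cout << SlicedDimSize(7, 2, -2) << "\n";  // 3  (indices 7, 5, 3)
  std::cout << SlicedDimSize(3, 3, 1) << "\n";   // 0  (empty slice)
  return 0;
}
```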
- if (sparse.ellipsis_mask == 0) { - Set(sparse.ellipsis_mask, sparse.dims); - sparse.dims++; - } - - int64_t dims = input_shape.size(); - DenseSliceSpec dense = {dims, - /*begin_mask = */ 0, - /*end_mask = */ 0, - /*shrink_axis_mask = */ 0, - *begin, - *end, - *stride}; - - BuildDenseSliceSpec(sparse, &dense); - CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask, - dense.end_mask, dense.shrink_axis_mask, - *begin, *end, *stride); -} - -bool StridedSliceOp::GetSlicedBoundRanges( - SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, - SmallVectorImpl *slice_stride) { - // TODO(hinsu): Support lowering for ops with dynamic begin and end values - // when it is possible to derive indices based on mask attributes. - DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) - return false; - - auto input_ty = this->input().getType().dyn_cast(); - if (!input_ty || !input_ty.hasStaticShape()) return false; - auto input_shape = llvm::to_vector<4>(input_ty.getShape()); - - SmallVector sparse_begin, sparse_end, sparse_strides; - - for (const APInt &index : sparse_begin_attr) - sparse_begin.push_back(index.getSExtValue()); - for (const APInt &index : sparse_end_attr) - sparse_end.push_back(index.getSExtValue()); - for (const APInt &stride : sparse_strides_attr) - sparse_strides.push_back(stride.getSExtValue()); - - CalculateSlicedShapeFromSparseIndices( - input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); - return true; -} - -//===----------------------------------------------------------------------===// -// StridedSliceGradOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(StridedSliceGradOp op) { - auto shape_type = op.shape().getType().dyn_cast(); - if (shape_type && shape_type.getRank() != 1) - return op.emitOpError("'shape' operand must be 1D tensor, but got ") - << shape_type.getRank() << "D tensor"; - - if (failed(VerifyStridedSliceBase(op))) return failure(); - - // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with - // the sliced type from StridedSlice. 
- - return success(); -} - -bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( - SmallVectorImpl *input_shape, - SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, - SmallVectorImpl *slice_stride) { - DenseIntElementsAttr shape_attr; - DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; - if (!matchPattern(shape(), m_Constant(&shape_attr)) || - !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || - !matchPattern(end(), m_Constant(&sparse_end_attr)) || - !matchPattern(strides(), m_Constant(&sparse_strides_attr))) - return false; - - int rank = std::distance(shape_attr.begin(), shape_attr.end()); - - input_shape->clear(); - input_shape->reserve(rank); - for (const APInt &dim : shape_attr) - input_shape->push_back(dim.getSExtValue()); - - SmallVector sparse_begin, sparse_end, sparse_strides; - - for (const APInt &index : sparse_begin_attr) - sparse_begin.push_back(index.getSExtValue()); - for (const APInt &index : sparse_end_attr) - sparse_end.push_back(index.getSExtValue()); - for (const APInt &stride : sparse_strides_attr) - sparse_strides.push_back(stride.getSExtValue()); - - CalculateSlicedShapeFromSparseIndices( - *input_shape, sparse_begin, sparse_end, sparse_strides, - begin_mask().getZExtValue(), end_mask().getZExtValue(), - ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), - shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); - return true; -} - -//===----------------------------------------------------------------------===// -// TensorListReserveOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorListReserveOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - - if (!IsOfRankOrUnranked(op.num_elements(), 0)) { - return op.emitOpError("requires num_elements operand to be 0D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// TensorListElementShapeOp -//===----------------------------------------------------------------------===// - -OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto variant_type = - getElementTypeOrSelf(getOperand().getType()).cast(); - if (variant_type.getSubtypes().empty()) return {}; - return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); -} - -//===----------------------------------------------------------------------===// -// TensorListStackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorListStackOp op) { - if (!IsOfRankOrUnranked(op.element_shape(), 0) && - !IsOfRankOrUnranked(op.element_shape(), 1)) { - return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); - } - return success(); -} - -//===----------------------------------------------------------------------===// -// TensorScatterUpdateOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TensorScatterUpdateOp op) { - if (!HasRankAtLeast(op.tensor(), 1)) - return op.emitOpError( - "requires tensor operand to have at least 1 dimension"); - if (!HasRankAtLeast(op.indices(), 1)) - return op.emitOpError( - "requires indices operand to have at least 1 dimension"); - if (!HasRankAtLeast(op.updates(), 
1)) - return op.emitOpError( - "requires updates operand to have at least 1 dimension"); - - auto tensor_ty = op.tensor().getType().dyn_cast(); - auto indices_ty = op.indices().getType().dyn_cast(); - if (!tensor_ty || !indices_ty) return success(); - - int64_t num_index_dims = indices_ty.getShape().back(); - if (ShapedType::isDynamic(num_index_dims)) return success(); - - if (num_index_dims > tensor_ty.getRank()) - return op.emitOpError( - "requires tensor operand with rank greater than or equal to the " - "indices operand's last dimensions"); - return success(); -} - -//===----------------------------------------------------------------------===// -// TopKV2Op -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TopKV2Op op) { - if (!HasRankAtLeast(op.input(), 1)) - return op.emitOpError( - "requires input operand to have at least 1 dimension"); - - if (!IsOfRankOrUnranked(op.k(), 0)) - return op.emitOpError("requires k operand to be 0D tensor"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// ToBoolOp -//===----------------------------------------------------------------------===// - -namespace { -// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity -// function and can be removed. -class ToBoolOfZeroDBoolTensor : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ToBoolOp op, - PatternRewriter &rewriter) const override { - if (auto type = op.getOperand().getType().dyn_cast()) { - if (type.getRank() == 0 && type.getElementType().isInteger(1)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } - } - return failure(); - } -}; -} // namespace - -void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// TransposeOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(TransposeOp op) { - // TODO(hinsu): Verify using a custom verifier that, - // * Transpose permutation is 1-D of size equal to the rank of the first - // input, if the shapes are partially known. Requires use of a more - // restrictive type than TF_Tensor. - // * Result shape dimensions are possible based on the input shape. - return success(); -} - -// TODO(jpienaar): perm could be optional too. -void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, - Value perm) { - auto x_type = x.getType().cast(); - // If value is unranked, then so is results. - if (!x_type.hasRank()) - return TransposeOp::build(builder, result, - UnrankedTensorType::get(x_type.getElementType()), - x, perm); - - // TODO(jpienaar): Handle unknown perm case. - - // TODO(jpienaar): Extract utility function. 
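The TransposeOp builder continues below by re-indexing the operand shape with a constant permutation; a standalone sketch of that shape rule, with illustrative shapes only:

```c++
#include <cstdint>
#include <iostream>
#include <vector>

// With a constant permutation, the transpose result shape is the input shape
// re-indexed by the permutation: result[i] = shape[perm[i]].
std::vector<int64_t> PermutedShape(const std::vector<int64_t> &shape,
                                   const std::vector<int64_t> &perm) {
  std::vector<int64_t> result;
  result.reserve(perm.size());
  for (int64_t p : perm) result.push_back(shape[p]);
  return result;
}

int main() {
  // Transposing a 2x3x4 tensor with perm = [2, 0, 1] yields shape 4x2x3.
  for (int64_t d : PermutedShape({2, 3, 4}, {2, 0, 1})) std::cout << d << " ";
  std::cout << "\n";
  return 0;
}
```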
- auto etype = x_type.cast().getElementType(); - DenseIntElementsAttr attr_shape; - if (matchPattern(perm, m_Constant(&attr_shape))) { - llvm::SmallVector const_shape; - if (attr_shape.isSplat()) { - const_shape.assign( - attr_shape.getNumElements(), - x_type.getDimSize((*attr_shape.begin()).getSExtValue())); - } else { - const_shape.reserve(attr_shape.getNumElements()); - for (const auto &dim : attr_shape) - const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); - } - return TransposeOp::build( - builder, result, RankedTensorType::get(const_shape, etype), x, perm); - } - return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x, - perm); -} - -namespace { - -OpFoldResult FoldIdentityTranspose(TransposeOp op) { - auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); - if (!const_perm) return {}; - - auto const_value = const_perm.value(); - const auto elements = const_value.getValues(); - - for (auto it : llvm::enumerate(elements)) { - if (it.index() != it.value()) return {}; - } - - // TODO(jpienaar): Remove if/when we handle this more generally. - if (op.getType() != op.x().getType()) { - // If the types don't match then only fold if all the operands are in the TF - // dialect. - for (auto user : op.getOperation()->getUsers()) - if (user->getDialect() != op.getDialect()) return {}; - } - - return op.x(); -} - -OpFoldResult FoldCancellableTranspose(TransposeOp op) { - // Operand is a TransposeOp. - auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); - if (!transpose) return {}; - - // Permutations defined by constant operations. - auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); - auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); - if (!perm0 || !perm1) return {}; - - // With permutation indices that cancel each other - auto perm0_value = perm0.value().cast(); - auto perm1_value = perm1.value().cast(); - if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; - - return transpose.x(); -} - -} // namespace - -OpFoldResult TransposeOp::fold(ArrayRef operands) { - if (auto folded = FoldIdentityTranspose(*this)) return folded; - if (auto folded = FoldCancellableTranspose(*this)) return folded; - return {}; -} - -//===----------------------------------------------------------------------===// -// TruncateDivOp -//===----------------------------------------------------------------------===// - -void TruncateDivOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// UnpackOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(UnpackOp op) { - auto value_type = op.value().getType().dyn_cast(); - if (!value_type) return success(); - - int64_t value_rank = value_type.getRank(); - int64_t axis = op.axis().getSExtValue(); - if (axis < -value_rank || axis >= value_rank) - return op.emitOpError("axis attribute must be in the range of [-") - << value_rank << ", " << value_rank << ')'; - - axis = GetDimForAxis(axis, value_rank); - int64_t dim_size = value_type.getDimSize(axis); - if (ShapedType::isDynamic(dim_size)) return success(); - - if (dim_size != op.getNumResults()) - return op.emitOpError("result count must be equal to ") << dim_size; - - return success(); -} - -//===----------------------------------------------------------------------===// -// Unsorted segment reduction ops 
-//===----------------------------------------------------------------------===// - -template -static LogicalResult VerifyUnsortedSegmentReduction(Op op) { - if (!HasRankAtMost(op.num_segments(), 0)) - return op.emitOpError("number of segments should be a 0-D tensor"); - - auto data_type = op.data().getType().template dyn_cast(); - auto segment_ids_type = - op.segment_ids().getType().template dyn_cast(); - if (data_type && segment_ids_type) { - if (data_type.getRank() < segment_ids_type.getRank()) - return op.emitOpError( - "requires segment ids rank to be less than or equal to data's rank"); - - int index = 0; - for (auto shape_pair : - llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) { - int64_t segment_id_dim = std::get<0>(shape_pair); - int64_t data_dim = std::get<1>(shape_pair); - if (!ShapedType::isDynamic(segment_id_dim) && - !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim) - return op.emitOpError( - "requires segment ids shape to be a prefix of data shape, " - "but dimension #") - << index << " differs: " << segment_id_dim << " vs. " - << data_dim; - ++index; - } - } - - DenseIntElementsAttr num_segments_attr; - if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) { - int64_t num_segments = (*num_segments_attr.begin()).getSExtValue(); - if (num_segments < 0) - return op.emitOpError("num of segments cannot be negative"); - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// VariableShapeOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(VariableShapeOp op) { - auto input_type = op.input().getType().cast(); - if (input_type.hasStaticShape() && input_type.getNumElements() != 1) - return op.emitOpError("requires input to have one resource"); - - auto resource_type = input_type.getElementType().cast(); - auto subtypes = resource_type.getSubtypes(); - switch (subtypes.size()) { - case 1: - return VerifyShapeOperandAndResult( - op, resource_type.getSubtypes().front(), op.getType()); - case 0: - return VerifyShapeOperandAndResult(op, Type(), op.getType()); - default: - return op.emitOpError( - "requires resource input type to have at most 1 subtype"); - } -} - -OpFoldResult VariableShapeOp::fold(ArrayRef operands) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto resource_type = - getElementTypeOrSelf(getOperand().getType()).cast(); - if (resource_type.getSubtypes().empty()) return {}; - return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); -} - -//===----------------------------------------------------------------------===// -// WhileOp -//===----------------------------------------------------------------------===// - -static LogicalResult Verify(WhileOp op) { - auto module = op.getParentOfType(); - auto cond_fn = module.lookupSymbol(op.cond()); - auto body_fn = module.lookupSymbol(op.body()); - if (!cond_fn) { - return op.emitOpError("cond refers to an undefined function : ") - << op.cond(); - } - if (!body_fn) { - return op.emitOpError("body refers to an undefined function : ") - << op.body(); - } - - auto cond_fn_type = cond_fn.getType(); - auto body_fn_type = body_fn.getType(); - - // Verify that the cond function has exactly one result. 
- if (cond_fn_type.getNumResults() != 1) - return op.emitOpError("requires cond function to have exactly one result"); - - SmallVector operands(op.getOperandTypes()); - - // Collect all the type lists for the op so that different pairs of type lists - // can be compared for the compatibility. - constexpr int kNumTypeLists = 5; - const std::array>, kNumTypeLists> - type_lists = {{ - {"operand", operands}, - {"body function result", body_fn_type.getResults()}, - {"result", op.getResultTypes()}, - {"cond function input", cond_fn_type.getInputs()}, - {"body function input", body_fn_type.getInputs()}, - }}; - - // A pair of type lists should be cast compatible with each other if one is - // converted to the another for a function call or assignment or there is a - // common source of inputs for both. Therefore, the While op requires the - // following pairs of type lists to be cast compatible for the tensor_cast - // operation: - // - // * Operands and cond inputs to call the cond function before the - // first iteration. - // * Operands and body inputs to call the body function for the first - // iteration if the cond functions returns True or equivalent result. - // * Operands and results to assign cond function arguments to op results if - // the cond function returns False or equivalent result. - // * All three pairs using cond inputs, body inputs and results as operand is - // a common source for all three. - // * Body result and cond inputs to call the cond function for the subsequent - // iterations. Similarly, Body result should be compatible with body inputs - // and op results. - // - // Note that the operands and body results need not be compatible as they are - // never converted from one to the another nor there is a common source - // tensors. Compatibility requirement is not transitive. - - for (int i = 0; i < kNumTypeLists; ++i) { - // Skip the first pair as the While op operands and body function results - // does not need to be compatible with each other. - for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { - auto &a = type_lists[i]; - auto &b = type_lists[j]; - - int a_size = a.second.size(); - if (a_size != b.second.size()) - return op.emitOpError( - llvm::formatv("requires the number of {0}s to be equal to the " - "number of {1}s. Found {2} and {3}, respectively", - a.first, b.first, a_size, b.second.size())); - - for (int idx = 0; idx < a_size; ++idx) { - auto a_type = a.second[idx]; - auto b_type = b.second[idx]; - - if (!AreCastCompatible({a_type, b_type})) - return op.emitError(llvm::formatv( - "{0} type {1} is incompatible with {2} type {3} at index {4}", - a.first, a_type, b.first, b_type, idx)); - } - } - } - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp -//===----------------------------------------------------------------------===// -static LogicalResult Verify(WhileRegionOp op) { - // Verify that the condition generates a single tensor result. - YieldOp yield = cast(op.cond().front().getTerminator()); - if (yield.getNumOperands() != 1) - return op.emitOpError() - << "condition should have a single tensor result"; - - auto cond_type = yield.getOperand(0).getType().dyn_cast(); - if (!cond_type || !cond_type.getShape().equals({}) || - !cond_type.getElementType().isInteger(/*width=*/1)) - return op.emitOpError() - << "condition should have a single tensor result"; - - // The body result types should match while op result types. 
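Stepping back to the WhileOp verifier above: a standalone sketch of exactly which pairs its nested loop compares, using the same `max(2, i + 1)` lower bound (plain strings stand in for the five type lists); the WhileRegionOp checks resume below.

```c++
#include <algorithm>
#include <array>
#include <iostream>

int main() {
  const std::array<const char *, 5> kLists = {
      "operand", "body function result", "result", "cond function input",
      "body function input"};
  for (int i = 0; i < 5; ++i)
    for (int j = std::max(2, i + 1); j < 5; ++j)
      std::cout << kLists[i] << " <-> " << kLists[j] << "\n";
  // The (operand, body function result) pair is skipped (j starts at 2 when
  // i == 0) because those two lists never need to be cast compatible.
  return 0;
}
```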
- if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); - - // Both condition and body should have same number and type of operands as - // the WhileRegion inputs. - const int num_inputs = op.getNumOperands(); - auto block_inputs_match_op_inputs = [&](Region ®ion, - StringRef name) -> LogicalResult { - Block &block = region.front(); - if (block.getNumArguments() != num_inputs) - return op.emitOpError() - << name << " should have same number of inputs (" << num_inputs - << ") as " << WhileRegionOp::getOperationName() << " but has " - << block.getNumArguments() << " inputs"; - - for (auto types_idx : llvm::enumerate( - llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { - auto op_input_type = std::get<0>(types_idx.value()); - auto block_input_type = std::get<1>(types_idx.value()); - if (!AreCastCompatible({block_input_type, op_input_type})) - return op.emitOpError(llvm::formatv( - "{0} input type {1} is incompatible with {2} " - "input type {3} at index {4}", - name, block_input_type, WhileRegionOp::getOperationName(), - op_input_type, types_idx.index())); - } - return success(); - }; - - if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || - failed(block_inputs_match_op_inputs(op.body(), "body"))) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp LoopLikeOpInterface -//===----------------------------------------------------------------------===// - -Region &WhileRegionOp::getLoopBody() { return body(); } - -bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) { - // If the Op defining the value exists and the defining op is outside the - // scope of this WhileRegion, then we can infer that its defined outside. - // The defining Op is outside the scope of this WhileRegion if this - // WhileRegionOp is not an ancestor of the defining op in the parent chain. - Operation *def_op = value.getDefiningOp(); - return def_op && !getOperation()->isAncestor(def_op); -} - -LogicalResult WhileRegionOp::moveOutOfLoop( - llvm::ArrayRef ops) { - // Move the hoisted value to just before the while. - Operation *while_op = this->getOperation(); - for (auto op : ops) op->moveBefore(while_op); - return success(); -} - -//===----------------------------------------------------------------------===// -// WhileRegionOp canonicalization -//===----------------------------------------------------------------------===// -namespace { -// Eliminate values that pass through the WhileRegionOp body. -struct WhileRegionEliminatePassThrough - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileRegionOp while_op, - PatternRewriter &rewriter) const override { - // Replace values that simply passthrough the body with extern values. The - // block arguments of body and while match and so the corresponding cond - // argument can be easily found. - int old_num_operands = while_op.getNumOperands(); - int new_num_operands = old_num_operands; - auto &body_block = while_op.body().front(); - auto &cond_block = while_op.cond().front(); - auto &yield = *body_block.getTerminator(); - - // Bit mask indicating which operands will be removed. 
- SmallVector removed_operand(old_num_operands, false); - - for (int op_idx : llvm::seq(0, old_num_operands)) { - auto body_arg = body_block.getArgument(op_idx); - if (body_arg == yield.getOperand(op_idx)) { - // Replace the use of the passthrough value with the while operand - // in the body and condition regions, as well as the while output (if - // type match) - // TODO(jurahul): Use PatternRewriter API for IR modification. - auto value = while_op.getOperand(op_idx); - if (body_arg.getType() == value.getType()) - body_arg.replaceAllUsesWith(value); - - auto cond_arg = cond_block.getArgument(op_idx); - if (cond_arg.getType() == value.getType()) - cond_arg.replaceAllUsesWith(value); - - auto result = while_op.getResult(op_idx); - if (result.getType() == value.getType()) - result.replaceAllUsesWith(value); - } - - // Now check if the operand is unused in both regions as well as the - // result. If so, mark it for removal. - if (body_block.getArgument(op_idx).use_empty() && - cond_block.getArgument(op_idx).use_empty() && - while_op.getResult(op_idx).use_empty()) { - removed_operand[op_idx] = true; - new_num_operands--; - } - } - - if (new_num_operands == old_num_operands) return failure(); - - // Compress the operands, region arguments, and outputs. - SmallVector new_while_operands; - SmallVector new_result_types; - new_while_operands.reserve(new_num_operands); - new_result_types.reserve(new_num_operands); - - // Build new operands and result type. - int next_idx = 0; - for (int op_idx : llvm::seq(0, old_num_operands)) { - if (removed_operand[op_idx]) continue; - new_while_operands.push_back(while_op.getOperand(op_idx)); - new_result_types.push_back(while_op.getResult(op_idx).getType()); - next_idx++; - } - - // Create the new while operation. - auto new_while_op = - rewriter.create(while_op.getLoc(), new_result_types, - new_while_operands, while_op.getAttrs()); - - // Move region bodies to the new while. - rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), - new_while_op.cond().end()); - rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), - new_while_op.body().end()); - - auto &new_cond_block = new_while_op.cond().front(); - auto &new_body_block = new_while_op.body().front(); - auto &new_yield = *new_body_block.getTerminator(); - - // Build a vector of new results. Also patch up the region bodies and yield. 
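A standalone sketch of the index compression this pattern performs (the real loops appear just above and below); plain strings and a boolean mask stand in for the while operands and the `removed_operand` vector:

```c++
#include <cstdio>
#include <vector>

int main() {
  const std::vector<const char *> old_operands = {"%a", "%b", "%c", "%d"};
  const std::vector<bool> removed = {false, true, false, true};

  std::vector<const char *> new_operands;
  std::vector<int> old_to_new(old_operands.size(), -1);  // -1 means dropped
  for (size_t i = 0; i < old_operands.size(); ++i) {
    if (removed[i]) continue;
    old_to_new[i] = static_cast<int>(new_operands.size());
    new_operands.push_back(old_operands[i]);
  }
  for (size_t i = 0; i < old_operands.size(); ++i)
    std::printf("old index %zu (%s) -> new index %d\n", i, old_operands[i],
                old_to_new[i]);
  return 0;
}
```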
- SmallVector new_results; - next_idx = 0; - for (int op_idx : llvm::seq(0, old_num_operands)) { - if (removed_operand[op_idx]) { - new_cond_block.eraseArgument(next_idx); - new_body_block.eraseArgument(next_idx); - new_yield.eraseOperand(next_idx); - new_results.push_back(nullptr); - } else { - new_results.push_back(new_while_op.getResult(next_idx++)); - } - } - - rewriter.replaceOp(while_op, new_results); - return success(); - } -}; - -} // anonymous namespace - -void WhileRegionOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// XdivyOp -//===----------------------------------------------------------------------===// - -void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// TableGen'd op method definitions -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc" - //===----------------------------------------------------------------------===// // TF Dialect Interfaces //===----------------------------------------------------------------------===// namespace { +// Returns true if the op can be duplicated. +bool CanDuplicate(Operation *op) { + // If the op is marked with the cannot duplicate trait, it cannot be + // duplicated. + if (op->hasTrait()) return false; + + // If the op has no memory side effects, it can be duplicated. + if (MemoryEffectOpInterface::hasNoEffect(op)) return true; + + // If the op is marked stateless using the `is_stateless` attribute, that + // attribute determines if the op can be duplicated. + if (auto is_stateless = op->getAttrOfType("is_stateless")) + return is_stateless.getValue(); + + // Otherwise, assume ops can be duplicated by default. + return true; +} + +// Returns true of the given function has a single uses (within the scope +// of the module containing it and all parent modules). +bool HasSingleUse(FuncOp func) { + // Public function can have any number of external uses. + if (func.isPublic()) return false; + + // Return false if unexpected IR structure seen. + ModuleOp module = func.getParentOfType(); + if (!module) return false; + + // Inspect function uses in the containing module and all parent + // modules. + bool use_seen = false; + for (; module; module = module.getParentOfType()) { + auto func_uses_optional = + SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + // Found an unknown use. + if (!func_uses_optional) return false; + + // If no uses in this scope, continue looking in parent module + SymbolTable::UseRange func_uses = func_uses_optional.getValue(); + if (func_uses.empty()) continue; + + // Check if multiple uses at this scope or another use already seen. + if (!llvm::hasSingleElement(func_uses) || use_seen) return false; + + // This is the first use seen. + use_seen = true; + + // If the function is private, no need to inspect parent modules. + if (func.isPrivate()) break; + } + + // No multiple uses seen. 
+ return true; +} + struct TFInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -4433,8 +136,8 @@ struct TFInlinerInterface : public DialectInlinerInterface { // Analysis Hooks //===--------------------------------------------------------------------===// - // Defines the legality of inlinining 'src' region into the 'dest' region - // attached to a TF operation + // Returns if its legal to inline 'src' region into the 'dest' region + // attached to a TF operation. bool isLegalToInline(Region *dest, Region *src, BlockAndValueMapping &valueMapping) const final { // Allow inlining in regions attached to region based control flow @@ -4443,13 +146,17 @@ struct TFInlinerInterface : public DialectInlinerInterface { llvm::hasSingleElement(*src); } - // Defines the legality of inlining TF operations. - bool isLegalToInline(Operation *, Region *, + // Returns true if its legal to inline a TF operation `op` into the `dest` + // region. + bool isLegalToInline(Operation *op, Region *dest, BlockAndValueMapping &) const final { - // TODO(riverriddle) For now, enable inlining all operations. This isn't - // correct in the face of operations that cannot be duplicated, but this - // requires more intricate side-effect modeling. - return true; + // An op is legal to inline if either of the following conditions is true: + // (a) Its legal to duplicate the Op. + // (a) The Op is inside a single use function. If that function is inlined, + // post inlining, the function will be dead and eliminated from the IR. + // So there won't be any code duplication. + FuncOp func = op->getParentOfType(); + return !func || CanDuplicate(op) || HasSingleUse(func); } //===--------------------------------------------------------------------===// @@ -4476,17 +183,15 @@ struct TFInlinerInterface : public DialectInlinerInterface { // TF Dialect //===----------------------------------------------------------------------===// -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.cc.inc" - std::vector *TensorFlowDialect::additional_operation_hooks_ = new std::vector(); TensorFlowDialect::TensorFlowDialect(MLIRContext *context) - : Dialect(/*name=*/"tf", context) { + : Dialect(/*name=*/"tf", context, TypeID::get()) { addOperations< #define GET_OP_LIST -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc.inc" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.cc.inc" >(); addTypes< #define HANDLE_TF_TYPE(tftype, enumerant, name) tftype##Type, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index f37b71575f6..039ed1bc3a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project @@ -35,6 +36,9 @@ limitations under the License. 
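Stepping back to the inliner hook above: a standalone sketch of the legality rule it implements, with plain bools standing in for the trait, side-effect, and attribute queries (illustrative only); the tf_ops.h header changes continue below.

```c++
#include <iostream>

struct OpInfo {
  bool has_cannot_duplicate_trait;
  bool has_no_memory_effects;
  bool is_stateless_attr_value;  // value of the `is_stateless` attribute
  bool has_is_stateless_attr;
};

bool CanDuplicate(const OpInfo &op) {
  if (op.has_cannot_duplicate_trait) return false;
  if (op.has_no_memory_effects) return true;
  if (op.has_is_stateless_attr) return op.is_stateless_attr_value;
  return true;  // default: assume the op can be duplicated
}

// Inlining is legal if the op is outside any function, can be duplicated, or
// lives in a single-use function (which dies after inlining, so no copies).
bool IsLegalToInline(const OpInfo &op, bool in_function,
                     bool func_has_single_use) {
  return !in_function || CanDuplicate(op) || func_has_single_use;
}

int main() {
  OpInfo stateful_op{/*trait=*/false, /*no_effects=*/false,
                     /*is_stateless=*/false, /*has_attr=*/true};
  std::cout << IsLegalToInline(stateful_op, /*in_function=*/true,
                               /*func_has_single_use=*/true)
            << "\n";  // 1: single-use function, so inlining duplicates nothing
  return 0;
}
```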
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -112,17 +116,6 @@ class TensorFlowDialect : public Dialect { static std::vector *additional_operation_hooks_; }; -// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose -// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use -// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and -// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to -// undefine here to avoid expanding the getter symbol as macro when including -// both mutex.h and this header file. -#undef mutex_lock - -#define GET_OP_CLASSES -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h.inc" - } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 7c6e6c672ae..5269bb82239 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -68,6 +68,51 @@ class TF_TensorListInitOp : TF_Op { }]; } +def TF_CaseOp : TF_Op<"Case", []> { + let summary = [{ +An n-way switch statement which calls a single branch function. + }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + Variadic:$input, + + Confined]>:$branches, + DefaultValuedAttr:$output_shapes, + + // Used to map StatelessCase and Case to a common op. + DefaultValuedAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let hasCanonicalizer = 1; +} + // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with // its type encoding the tensor's shape and data type. def TF_ConstOp : TF_Op<"Const", [ConstantLike, NoSideEffect, @@ -225,10 +270,25 @@ else_branch: A function that takes 'inputs' and returns a list of TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + // Get the then branch function. + FuncOp then_func() { + return SymbolTable::lookupNearestSymbolFrom(*this, then_branch()); + } + + // Get the else branch function. 
+ FuncOp else_func() { + return SymbolTable::lookupNearestSymbolFrom(*this, else_branch()); + } + }]; } def TF_YieldOp : TF_Op<"Yield", @@ -331,8 +391,8 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall", within the same symbol scope as the call and is mapped to a GraphDef node with the function name as the op name. Unlike a PartitionedCall which represents asynchronously executing a function across multiple devices, a - LegacyCall represents a function call with the only attribute - _diable_call_shape_inference. + LegacyCall ignores specification for ops in the attached function and + instead executes it on the device assigned to this op. }]; let arguments = (ins @@ -351,8 +411,11 @@ def TF_LegacyCallOp : TF_Op<"LegacyCall", operand_range getArgOperands() { return args(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("f"); + CallInterfaceCallable getCallableForCallee() { return fAttr(); } + + // returns the callee of this operation. + FuncOp func() { + return SymbolTable::lookupNearestSymbolFrom(*this, f()); } }]; } @@ -469,8 +532,11 @@ underlying graph, and executes each of the partitioned subgraphs as a function. operand_range getArgOperands() { return args(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("f"); + CallInterfaceCallable getCallableForCallee() { return fAttr(); } + + // returns the callee of this operation. + FuncOp func() { + return SymbolTable::lookupNearestSymbolFrom(*this, f()); } }]; @@ -575,8 +641,11 @@ underlying graph, and executes each of the partitioned subgraphs as a function. operand_range getArgOperands() { return args(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("f"); + CallInterfaceCallable getCallableForCallee() { return fAttr(); } + + // returns the callee of this operation. + FuncOp func() { + return SymbolTable::lookupNearestSymbolFrom(*this, f()); } }]; @@ -610,7 +679,6 @@ body: A function that takes a list of tensors and returns another FlatSymbolRefAttr:$cond, FlatSymbolRefAttr:$body, - DefaultValuedAttr:$output_shapes, DefaultValuedAttr:$parallel_iterations, // Used to map StatelessWhile and While op defined in TensorFlow to a common @@ -623,10 +691,24 @@ body: A function that takes a list of tensors and returns another ); TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; let verifier = [{ return Verify(*this); }]; + let hasCanonicalizer = 1; + + let extraClassDeclaration = [{ + // Get the condition function. + FuncOp cond_func() { + return SymbolTable::lookupNearestSymbolFrom(*this, cond()); + } + + // Get the body function. + FuncOp body_func() { + return SymbolTable::lookupNearestSymbolFrom(*this, body()); + } + }]; } def TL_WhileRegionOp : TF_Op<"WhileRegion", @@ -1068,31 +1150,6 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; } -// TODO(b/156507832): Move tf.InplaceUpdate to tf_generated_ops.td once -// autogenerated op def matches. -def TF_InplaceUpdateOp : TF_Op<"InplaceUpdate", [NoSideEffect]> { - let summary = "Updates specified rows 'i' with values 'v'."; - - let description = [{ -Computes `x[i, :] = v; return x`. - -Originally this function is mutative however for compilation we make this -operation create / operate on a copy of `x`. 
- }]; - - let arguments = (ins - TF_Tensor:$x, - I32Tensor:$i, - TF_Tensor:$v - ); - - let results = (outs - TF_Tensor:$y - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the Bessel i0e function of `x` element-wise."; @@ -1188,12 +1245,143 @@ def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { operand_range getArgOperands() { return args(); } // Returns the callee of this operation. - CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("f"); + CallInterfaceCallable getCallableForCallee() { return fAttr(); } + + // returns the callee of this operation. + FuncOp func() { + return SymbolTable::lookupNearestSymbolFrom(*this, f()); } }]; let verifier = [{ return VerifyPartitionedCall(*this); }]; } +class TF_FusedBatchNormOpBase : TF_Op { + let summary = "Batch normalization."; + + let description = [{ +Note that the size of 4D Tensors are defined by either "NHWC" or "NCHW". +The size of 1D Tensors matches the dimension C of the 4D Tensors. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {0}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + + // TF_LayoutSensitiveInterface: + StringRef GetOptimalLayout(const RuntimeDevices& devices); + LogicalResult UpdateDataFormat(StringRef data_format); + }]; +} + +def TF_FusedBatchNormV2Op : TF_FusedBatchNormOpBase<"FusedBatchNormV2"> { + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2 + ); +} + +def TF_FusedBatchNormV3Op : TF_FusedBatchNormOpBase<"FusedBatchNormV3"> { + let results = (outs + TensorOf<[BF16, F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); +} + +def TF_BatchFunctionOp : TF_Op<"BatchFunction", [AttrSizedOperandSegments]> { + let summary = [{ +Batches all the inputs tensors to the computation done by the function. + }]; + + let description = [{ +So, for example, in the following code + + ```python + + # This input will be captured. + y = tf.placeholder_with_default(1.0, shape=[]) + + @tf.Defun(tf.float32) + def computation(a): + return tf.matmul(a, a) + y + + b = gen_batch_ops.batch_function( + f=computation + in_tensors=[a], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="") + ``` + +If more than one session.run call is simultaneously trying to compute `b` +the values of `a` will be gathered, non-deterministically concatenated +along the first axis, and only one thread will run the computation. 
+ +Assumes that all arguments of the function are Tensors which will be batched +along their first dimension. + +Arguments that are captured, are not batched. The session.run call which does +the concatenation, will use the values of the captured tensors available to it. +Therefore, typical uses of captured tensors should involve values which remain +unchanged across session.run calls. Inference is a good example of this. + +SparseTensor is not supported. The return value of the decorated function +must be a Tensor or a list/tuple of Tensors. + }]; + + let arguments = (ins + Variadic:$in_tensors, + Variadic:$captured_tensors, + + SymbolRefAttr:$f, + I64Attr:$num_batch_threads, + I64Attr:$max_batch_size, + I64Attr:$batch_timeout_micros, + DefaultValuedAttr:$max_enqueued_batches, + DefaultValuedAttr:$allowed_batch_sizes, + StrAttr:$container, + StrAttr:$shared_name, + StrAttr:$batching_queue, + DefaultValuedAttr:$enable_large_batch_splitting, + I32ElementsAttr:$operand_segment_sizes + ); + + let results = (outs + Variadic:$out_tensors + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedOperandTypeListAttr Tcaptured = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc new file mode 100644 index 00000000000..1a730a38618 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -0,0 +1,2083 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace TF { + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// AddOp +//===----------------------------------------------------------------------===// + +void AddOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// AddNOp +//===----------------------------------------------------------------------===// + +OpFoldResult AddNOp::fold(ArrayRef operands) { + if (operands.size() == 1) return *inputs().begin(); + return {}; +} + +//===----------------------------------------------------------------------===// +// AddV2Op +//===----------------------------------------------------------------------===// + +void AddV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult AddV2Op::fold(ArrayRef operands) { + return 
IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// AllOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AllOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + +//===----------------------------------------------------------------------===// +// AnyOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AnyOp op) { + return VerifyReductionInputAndDims(op.input(), op.reduction_indices(), + op.getLoc()); +} + +//===----------------------------------------------------------------------===// +// AssertOp +//===----------------------------------------------------------------------===// + +namespace { + +// Removes Assert with constant true predicate. +struct AssertWithTrue : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AssertOp op, + PatternRewriter &rewriter) const override { + ElementsAttr cst; + if (matchPattern(op.condition(), m_Constant(&cst))) { + if (cst.getValue({}).getValue()) { + rewriter.eraseOp(op); + return success(); + } + } + return failure(); + } +}; +} // namespace + +void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchMatMulOp +//===----------------------------------------------------------------------===// + +void BatchMatMulOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchMatMulV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BatchMatMulV2Op op) { + if (!HasRankAtLeast(op.x(), 2)) { + return op.emitOpError("requires lhs operand to have rank at least two"); + } + if (!HasRankAtLeast(op.y(), 2)) { + return op.emitOpError("requires rhs operand to have rank at least two"); + } + return success(); +} + +void BatchMatMulV2Op::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BatchToSpaceOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BatchToSpaceOp op) { + // Op already has a constraint that block_size >= 2. 
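Returning to FoldConstantCaseOp above: a standalone sketch of the idea that a constant, in-range branch index reduces an n-way Case to a direct call (plain function pointers stand in for the branch FuncOps); the BatchToSpace verifier continues below.

```c++
#include <iostream>
#include <vector>

int BranchA(int x) { return x + 1; }
int BranchB(int x) { return x * 2; }

int main() {
  std::vector<int (*)(int)> branches = {BranchA, BranchB};
  const int branch_index = 1;  // imagine this came from a constant op
  if (branch_index >= 0 && branch_index < static_cast<int>(branches.size()))
    std::cout << branches[branch_index](21) << "\n";  // 42: direct BranchB call
  return 0;
}
```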
+ int64_t block_size = op.block_size().getSExtValue(); + + llvm::SmallVector input_shape(4, ShapedType::kDynamicSize); + auto input_type = op.input().getType().cast(); + if (input_type.hasRank()) { + if (input_type.getRank() != 4) + return op.emitOpError() + << "requires input to be a 4D tensor, but got " << input_type; + + int64_t input_batch = input_type.getDimSize(0); + if (input_batch != ShapedType::kDynamicSize && + input_batch % (block_size * block_size) != 0) { + return op.emitOpError() + << "requires input batch (dimension 0) to be evenly divisible " + "by (block_size * block_size), but got input batch " + << input_batch << " and block_size " << block_size; + } + + input_shape.assign(input_type.getShape().begin(), + input_type.getShape().end()); + } + + auto crops_type = op.crops().getType().cast(); + if (crops_type.hasRank()) { + if (crops_type.getRank() != 2) + return op.emitOpError() + << "requires crops to be a 2D tensor, but got " << crops_type; + + auto dim_of_size = [&](int64_t dim, int64_t size) { + if (crops_type.isDynamicDim(dim)) return true; + return crops_type.getDimSize(dim) == size; + }; + if (!dim_of_size(0, 2) || !dim_of_size(1, 2)) + return op.emitOpError() + << "requires crops to be a tensor<2x2>, but got " << crops_type; + } + + DenseIntElementsAttr crops_attr; + // Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]], + // and flattened as [crop_top, crop_bottom, crop_left, crop_right] + llvm::SmallVector crops_values; + if (matchPattern(op.crops(), m_Constant(&crops_attr))) { + assert(crops_attr.getNumElements() == 4 && + "tf.BatchToSpace crops must have 4 elements"); + + auto crops_range = crops_attr.getIntValues(); + for (const auto &crops_value : crops_range) { + int64_t crops_value_int = crops_value.getSExtValue(); + if (crops_value_int < 0) + return op.emitOpError() + << "requires all crop values to be nonnegative, but got " + << crops_attr; + + crops_values.push_back(crops_value_int); + } + } + + auto output_type = op.output().getType().cast(); + if (output_type.hasRank()) { + if (output_type.getRank() != 4) + return op.emitOpError() + << "requires output to be a 4D tensor, but got " << output_type; + + auto static_dims = [](int64_t dim_a, int64_t dim_b) { + return dim_a != ShapedType::kDynamicSize && + dim_b != ShapedType::kDynamicSize; + }; + + auto output_shape = output_type.getShape(); + + // output batch = input batch / (block_size * block_size). + int64_t input_batch = input_shape[0]; + int64_t output_batch = output_shape[0]; + if (static_dims(input_batch, output_batch) && + (output_batch * block_size * block_size) != input_batch) + return op.emitOpError() + << "requires output batch (dimension 0) to be equal to input " + "batch (dimension 0) / (block_size * block_size), but got " + "output batch " + << output_batch << ", input batch " << input_batch + << ", and block_size " << block_size; + + auto check_spatial_dim = [&](int64_t spatial_dim_index, + llvm::StringRef dim_name, + llvm::StringRef crop_a_name, + llvm::StringRef crop_b_name) -> LogicalResult { + int64_t input_dim = input_shape[spatial_dim_index]; + int64_t output_dim = output_shape[spatial_dim_index]; + if (!static_dims(input_dim, output_dim)) return success(); + + int64_t input_dim_pad = input_dim * block_size; + // If crops are unknown, the maximum output spatial dim size is input + // spatial dim size * block_size, as crops can be minimum 0. 
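Before the spatial-dimension checks continue below, a worked example of the shape relations being verified, using assumed values (block_size 2, NHWC input [8, 3, 5, 16], crops [[0, 1], [1, 0]]; all values are illustrative):

```c++
#include <cstdint>
#include <iostream>

int main() {
  const int64_t block_size = 2;
  const int64_t input[4] = {8, 3, 5, 16};        // batch, height, width, depth
  const int64_t crops[2][2] = {{0, 1}, {1, 0}};  // [[top, bottom], [left, right]]

  // Input batch must be divisible by block_size * block_size (8 % 4 == 0).
  int64_t output[4];
  output[0] = input[0] / (block_size * block_size);               // 2
  output[1] = input[1] * block_size - crops[0][0] - crops[0][1];  // 5
  output[2] = input[2] * block_size - crops[1][0] - crops[1][1];  // 9
  output[3] = input[3];                                           // 16
  for (int64_t d : output) std::cout << d << " ";                 // 2 5 9 16
  std::cout << "\n";
  return 0;
}
```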
+ if (crops_values.empty() && output_dim > input_dim * block_size) + return op.emitOpError() + << "requires output " << dim_name << " (dimension " + << spatial_dim_index << ") to be less than or equal to input " + << dim_name << " (dimension " << spatial_dim_index + << ") * block_size, but got output " << dim_name << " " + << output_dim << ", input " << dim_name << " " << input_dim + << ", and block_size " << block_size; + + if (!crops_values.empty()) { + // output spatial dim = input spatial dim * block_size - crops. + int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)]; + int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1]; + if (output_dim != input_dim_pad - crop_a - crop_b) + return op.emitOpError() + << "requires output " << dim_name << " (dimension " + << spatial_dim_index << ") to be equal to input " << dim_name + << " (dimension " << spatial_dim_index << ") * block_size - " + << crop_a_name << " - " << crop_b_name << ", but got output " + << dim_name << " " << output_dim << ", input " << dim_name + << " " << input_dim << ", " << crop_a_name << " " << crop_a + << ", " << crop_b_name << " " << crop_b << ", and block_size " + << block_size; + } + + return success(); + }; + + if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) || + failed(check_spatial_dim(2, "width", "crop_left", "crop_right"))) + return failure(); + + int64_t input_depth = input_shape[3]; + int64_t output_depth = output_shape[3]; + if (static_dims(input_depth, output_depth) && output_depth != input_depth) + return op.emitOpError() + << "requires output depth (dimension 3) to be equal to input " + "depth (dimension 3), but got output depth " + << output_depth << " and input depth " << input_depth; + } + + return success(); +} + +void BatchToSpaceOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BiasAddOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * the value and bias operands have valid ranks or are unranked. +// * Channel dimension of the value operand and length of bias matches if they +// are not unknown. +// +static LogicalResult Verify(BiasAddOp op) { + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); + tensorflow::TensorFormat format; + bool is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + if (format == tensorflow::TensorFormat::FORMAT_NHWC) { + if (!HasRankAtLeast(op.value(), 2)) + return op.emitOpError( + "requires value operand to have rank at least two with `NHWC` data " + "format"); + } else { + // Op definition requires data_format to be either NHWC or NCHW. 
+ DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); + if (!HasRankAtLeast(op.value(), 3)) + return op.emitOpError( + "requires value operand to have rank at least three with `NCHW` data " + "format"); + } + + if (!IsOfRankOrUnranked(op.bias(), 1)) + return op.emitOpError("requires bias operand to have rank exactly one"); + + RankedTensorType value_ty = op.value().getType().dyn_cast(); + RankedTensorType bias_ty = op.bias().getType().dyn_cast(); + if (!bias_ty || !value_ty) return success(); + + int64_t feature_dim_idx = + tensorflow::GetTensorFeatureDimIndex(value_ty.getRank(), format); + int64_t feature_dim = value_ty.getDimSize(feature_dim_idx); + int64_t bias_len = bias_ty.getDimSize(0); + if (feature_dim != -1 && bias_len != -1 && feature_dim != bias_len) { + return op.emitOpError() + << "requires channel dimension and feature dimension to match; " + "found " + << feature_dim << " and " << bias_len << ", respectively"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// BiasAddGradOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * the out_backprop operands have valid ranks or are unranked. +// +static LogicalResult Verify(BiasAddGradOp op) { + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); + tensorflow::TensorFormat format; + bool is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + if (format == tensorflow::TensorFormat::FORMAT_NHWC) { + if (!HasRankAtLeast(op.out_backprop(), 2)) + return op.emitOpError( + "requires out_backprop operand to have rank at least two with `NHWC` " + "data format"); + } else { + // Op definition requires data_format to be either NHWC or NCHW. + DCHECK_EQ(format, tensorflow::TensorFormat::FORMAT_NCHW); + if (!HasRankAtLeast(op.out_backprop(), 3)) + return op.emitOpError( + "requires out_backprop operand to have rank at least three with " + "`NCHW` data format"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// BiasAddV1Op +//===----------------------------------------------------------------------===// + +void BiasAddV1Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BitcastOp +//===----------------------------------------------------------------------===// + +void BitcastOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// BroadcastToOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BroadcastToOp op) { + // TODO(antiagainst): check that + // * The 'shape' input is an 1-D int tensor. + // * Each dimension pair of the source and target shapes are either equal + // or one of them is one. 
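+  //   For example, broadcasting a tensor<1x3xf32> source to target shape
+  //   [4, 3] is valid because each dimension pair (1 vs 4, 3 vs 3) is either
+  //   equal or contains a 1.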
+ return success(); +} + +//===----------------------------------------------------------------------===// +// CaseOp +//===----------------------------------------------------------------------===// + +class FoldConstantCaseOp : public OpRewritePattern { + public: + explicit FoldConstantCaseOp(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::CaseOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult FoldConstantCaseOp::matchAndRewrite( + TF::CaseOp op, PatternRewriter &rewriter) const { + // Extract the constant cond value. + DenseIntElementsAttr branch; + if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); + + // Only attempt to fold scalar valued case statements. + // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. + if (!branch.getType().cast().getShape().empty()) + return failure(); + + int index = *branch.getValues().begin(); + // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. + if (index >= op.branches().size()) return failure(); + + auto func = op.branches()[index].cast(); + auto empty = rewriter.getStringAttr(""); + auto call_op = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, + /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); + PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + rewriter.replaceOp(op, call_op.getResults()); + return success(); +} + +void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +OpFoldResult CastOp::fold(ArrayRef operands) { + // Cast with the same type is a no-op. + Value operand = getOperand(); + if (getType() == operand.getType()) return operand; + return {}; +} + +//===----------------------------------------------------------------------===// +// ConcatOp and ConcatV2Op +//===----------------------------------------------------------------------===// + +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. + Operation::operand_range values = op.values(); + + int axis_idx = std::is_same() ? 0 : 1; + Value axis = *op.getODSOperands(axis_idx).begin(); + if (!HasRankAtMost(axis, 1)) { + return op.emitOpError( + "requires axis to be of scalar type (or vector type for older " + "versions)"); + } + + return VerifyTypesCompatibility(values, + /*mask_one_dim=*/true, op.getOperation()); +} + +void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +namespace { + +// Hoist coefficient-wise unary operation out of the Concat op: +// +// %0 = "tf.Log1p"(%arg_0) +// %1 = "tf.Log1p"(%arg_1) +// ... 
+// %n = "tf.Log1p"(%arg_n) +// %m = "tf.ConcatV2"(%0, %1, ..., %n, %axis) +// +// Rewrite it to: +// +// %0 = "tf.ConcatV2"(%arg_0, %arg_1, ..., %arg_n, %axis) +// %1 = "tf.Log1p"(%0) +class HoistCwiseUnaryOutOfConcat : public OpRewritePattern { + public: + explicit HoistCwiseUnaryOutOfConcat(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::ConcatV2Op op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult HoistCwiseUnaryOutOfConcat::matchAndRewrite( + TF::ConcatV2Op op, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + // All concat operands must be defined by ops. + Operation *first_arg_op = op.values().front().getDefiningOp(); + if (first_arg_op == nullptr) return failure(); + + // All concat operands must be produced by the coeff-wise unary operation. + if (!first_arg_op->hasTrait()) return failure(); + + // All concat operands must be defined by the op of same kind. + bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool { + Operation *arg_op = arg.getDefiningOp(); + return arg_op && arg_op->getName() == first_arg_op->getName(); + }); + if (!args_same_op) return failure(); + + // Collect unary operations operands. + auto unary_operands = llvm::map_range(op.values(), [](Value arg) -> Value { + return arg.getDefiningOp()->getOperand(0); + }); + SmallVector unary_ops_args(unary_operands); + + // Concatenate unary ops operands. + auto concat_unary_operands = + rewriter.create(loc, op.getType(), unary_ops_args, op.axis()); + + // Replace original concat with an unary op. + OperationState new_unary_op_state(loc, first_arg_op->getName().getStringRef(), + concat_unary_operands.getResult(), + op.getResult().getType(), + ArrayRef()); + Operation *new_unary_op = rewriter.createOperation(new_unary_op_state); + + rewriter.replaceOp(op, new_unary_op->getResults()); + + return success(); +} + +// Hoist coefficient-wise binary operation out of the Concat op: +// +// %0 = tf.Mul(%lhs_0, %rhs_0) +// %1 = tf.Mul(%lhs_1, %rhs_1) +// ... +// %n = tf.Mul(%lhs_n, %rhs_n) +// %m = tf.ConcatV2(%0, %1, ..., %n, %axis) +// +// Rewrite it to: +// +// %0 = tf.ConcatV2(%lhs0, %lhs1, ..., %lhs_n, %lhs_concat_axis) +// %1 = tf.ConcatV2(%rhs0, %rhs1, ..., %rhs_n, %rhs_concat_axis) +// %2 = tf.Mul(%0, %1) +// +// Because coefficient-wise binary operations support implicit broadcasting, we +// should be very careful with this optimization, and do not accidentally +// produce incorrect concat operations. +class HoistCwiseBinaryOutOfConcat : public OpRewritePattern { + public: + explicit HoistCwiseBinaryOutOfConcat(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::ConcatV2Op op, + PatternRewriter &rewriter) const override; + + private: + struct HoistParams { + SmallVector lhs_args; + SmallVector rhs_args; + int64_t lhs_axis; + int64_t rhs_axis; + Type lhs_concat_type; + Type rhs_concat_type; + }; + + // Returns parameters of a binary op hoisting out of concatenation if all of + // the operands are in one of the compatible configurations. + Optional GetHoistParams(TF::ConcatV2Op op, int64_t axis) const; +}; + +LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( + TF::ConcatV2Op op, PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + + // Axis must be a constant scalar value. 
+ DenseIntElementsAttr axis_attr; + if (!matchPattern(op.axis(), m_Constant(&axis_attr))) return failure(); + if (axis_attr.getNumElements() != 1) return failure(); + int64_t axis = + axis_attr.getSplatValue().getValue().getSExtValue(); + + // All concat operands must be defined by ops. + Operation *first_arg_op = op.values().front().getDefiningOp(); + if (first_arg_op == nullptr) return failure(); + + // All concat operands must be produced by the coeff-wise binary operation. + if (!first_arg_op->hasTrait()) return failure(); + + // All concat operands must be defined by the op of same kind. + bool args_same_op = llvm::all_of(op.values(), [&](Value arg) -> bool { + Operation *arg_op = arg.getDefiningOp(); + return arg_op && arg_op->getName() == first_arg_op->getName(); + }); + if (!args_same_op) return failure(); + + // Compute binary operands hoist parameters. + auto hoist_params = GetHoistParams(op, axis); + if (!hoist_params.hasValue()) return failure(); + + // New lhs and rhs concatenation axis. + auto axis_type = mlir::RankedTensorType::get({}, rewriter.getIntegerType(64)); + auto lhs_axis = rewriter.create( + loc, DenseIntElementsAttr::get(axis_type, hoist_params->lhs_axis)); + auto rhs_axis = rewriter.create( + loc, DenseIntElementsAttr::get(axis_type, hoist_params->rhs_axis)); + + // Concatenate binary ops operands on the new axis. + auto lhs_concat = rewriter.create( + loc, hoist_params->lhs_concat_type, hoist_params->lhs_args, lhs_axis); + auto rhs_concat = rewriter.create( + loc, hoist_params->rhs_concat_type, hoist_params->rhs_args, rhs_axis); + + // Replace original concat with a binary op. + OperationState new_binary_op_state( + loc, first_arg_op->getName().getStringRef(), + {lhs_concat.getResult(), rhs_concat.getResult()}, + op.getResult().getType(), ArrayRef()); + Operation *new_binary_op = rewriter.createOperation(new_binary_op_state); + + rewriter.replaceOp(op, new_binary_op->getResults()); + + return success(); +} + +Optional +HoistCwiseBinaryOutOfConcat::GetHoistParams(TF::ConcatV2Op op, + int64_t axis) const { + // Collects lhs or rhs arguments of concat op operands. + auto args = [&](int operand_idx) -> SmallVector { + auto range = llvm::map_range(op.values(), [&](Value arg) { + return arg.getDefiningOp()->getOperand(operand_idx); + }); + return {range.begin(), range.end()}; + }; + + // Returns true if all binary ops operands at `operand_idx` index are tensors + // of `axis + 1` rank and axis dim has size `1`. + auto is_all_tensors = [&](int operand_idx, int axis) -> bool { + return llvm::all_of(op.values(), [&](Value arg) -> bool { + auto operand = arg.getDefiningOp()->getOperand(operand_idx); + auto ranked = operand.getType().dyn_cast(); + return ranked && ranked.getRank() == (axis + 1) && + ranked.getShape()[axis] == 1; + }); + }; + + // Returns true if all binary ops operands at `operand_idx` index are scalars. + auto is_all_scalars = [&](int operand_idx) -> bool { + return llvm::all_of(op.values(), [&](Value arg) -> bool { + auto operand = arg.getDefiningOp()->getOperand(operand_idx); + auto ranked = operand.getType().dyn_cast(); + return ranked && ranked.hasRank() && ranked.getRank() == 0; + }); + }; + + // Concat result type must be a ranked tensor. + auto ranked = op.getType().dyn_cast(); + if (!ranked) return None; + + // TODO(ezhulenev): Add support for more valid concat patterns. + + // Tensor + Scalar: [..., 1] + [] <- scalar + // ^ + // \- axis is the innermost dimension. 
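+  //
+  // For example, concatenating %0 = tf.Mul(%a : tensor<4x1xf32>, %s0 : tensor<f32>)
+  // and %1 = tf.Mul(%b : tensor<4x1xf32>, %s1 : tensor<f32>) on axis = 1 becomes
+  // tf.Mul(tf.ConcatV2(%a, %b, axis = 1), tf.ConcatV2(%s0, %s1, axis = 0)).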
+ // + // Concatenate tensor arguments on the same axis as the original operation, + // and concatenate scalars into the vector. + if (is_all_tensors(0, axis) && is_all_scalars(1)) { + std::array rhs_dims{static_cast(op.values().size())}; + auto rhs_type = RankedTensorType::get(rhs_dims, ranked.getElementType()); + return HoistParams{args(0), args(1), axis, 0, op.getType(), rhs_type}; + } + + return None; +} + +} // namespace + +void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert( + context); +} + +//===----------------------------------------------------------------------===// +// ConcatOffsetOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ConcatOffsetOp op) { + if (op.N() < 2) + return op.emitOpError() << "requires N to be at least 2, got " << op.N(); + + if (op.shape().size() != op.offset().size()) + return op.emitOpError() + << "requires sizes of shapes and offsets to be the same, got sizes " + << op.shape().size() << " and " << op.offset().size(); + + auto ranked_dim = op.concat_dim().getType().dyn_cast(); + if (ranked_dim && ranked_dim.getRank() != 0) + return op.emitOpError() + << "requires concat_dim to be a scalar, got tensor of rank " + << ranked_dim.getRank(); + + int64_t num_dims = -1; + for (auto shape_offset_idx : + llvm::enumerate(llvm::zip(op.shape(), op.offset()))) { + Value shape = std::get<0>(shape_offset_idx.value()); + Value offset = std::get<1>(shape_offset_idx.value()); + const size_t idx = shape_offset_idx.index(); + + if (failed(verifyCompatibleShape(shape.getType(), offset.getType()))) + return op.emitOpError() << "requires operand and result " << idx + << " to have compatible shapes"; + + auto ranked_shape = shape.getType().dyn_cast(); + if (!ranked_shape) continue; + + if (ranked_shape.getRank() != 1) + return op.emitOpError() << "requires shape tensor operand " << idx + << " to be of rank 1, got tensor of rank " + << ranked_shape.getRank(); + + if (!ranked_shape.hasStaticShape()) continue; + + int64_t ranked_shape_dim = ranked_shape.getDimSize(0); + if (num_dims == -1) + num_dims = ranked_shape_dim; + else if (ranked_shape_dim != num_dims) + return op.emitOpError() + << "requires shape tensor (rank 1) operand " << idx + << " to be of length " << num_dims + << ", got tensor (rank 1) of length " << ranked_shape_dim; + } + + return success(); +} + +LogicalResult ConcatOffsetOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + // ConcatOffset must have its first operand be concat_dim and at least two + // shape tensors in variadic shapes operand. + if (operands.size() < 3) return failure(); + + // Check concat_dim is a scalar. + auto concat_dim_attr = operands[0].dyn_cast_or_null(); + if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) + return failure(); + + llvm::SmallVector shapes; + shapes.reserve(operands.size() - 1); + for (Attribute shape : llvm::drop_begin(operands, 1)) + if (auto shape_attr = shape.dyn_cast_or_null()) + shapes.push_back(shape_attr); + else + return failure(); + + // Check all shapes are vectors of the same length. + if (shapes.front().getType().getRank() != 1) return success(); + const int64_t num_dims = shapes.front().getNumElements(); + for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) + if (shape.getType().getRank() != 1 || shape.getNumElements() != num_dims) + return failure(); + + // Check concat_dim is within [-num_dims, num_dims). 
+ int32_t concat_dim = (*concat_dim_attr.getValues().begin()); + if (concat_dim < 0) concat_dim += num_dims; + if (concat_dim >= num_dims || concat_dim < 0) return failure(); + + // Check all elements besides at concat_dim match across all shape tensors. + SmallVector shape0; + shape0.reserve(num_dims); + for (int32_t dim : shapes.front().getValues()) shape0.push_back(dim); + + for (DenseIntElementsAttr shape : llvm::drop_begin(shapes, 1)) { + for (auto dims_and_idx : llvm::enumerate(llvm::zip(shape0, shape))) { + if (dims_and_idx.index() == concat_dim) continue; + + if (std::get<0>(dims_and_idx.value()) != + std::get<1>(dims_and_idx.value()).getSExtValue()) + return failure(); + } + } + + // Compute an exclusive cumulative sum of elements at concat_dim. + results.reserve(shapes.size()); + SmallVector cumulative_sum(num_dims, 0); + RankedTensorType offset_type = + RankedTensorType::get({num_dims}, IntegerType::get(32, getContext())); + for (DenseIntElementsAttr shape : shapes) { + results.push_back(DenseIntElementsAttr::get(offset_type, cumulative_sum)); + cumulative_sum[concat_dim] += shape.getValue(concat_dim); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ConjOp +//===----------------------------------------------------------------------===// + +void ConjOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute value. + return value(); +} + +// Builds a constant op with the specified attribute `value`. The result +// op's type is deduced from `value`; if `value` is of scalar type, +// wraps it up with a tensor type of empty shape. +// TODO(jpienaar): This one differs from the autogenerated one as it takes an +// attribute but always creates an ElementsAttr internally. +void ConstOp::build(OpBuilder &builder, OperationState &result, + Attribute value) { + ShapedType type; + if (auto elem_attr = value.dyn_cast()) { + return ConstOp::build(builder, result, elem_attr); + } else if (value.isa()) { + // All TensorFlow types must be tensor types. In the build() method, + // we want to provide more flexibility by allowing attributes of scalar + // types. But we need to wrap it up with ElementsAttr to construct + // valid TensorFlow constants. + type = RankedTensorType::get(/*shape=*/{}, value.getType()); + return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); + } + // TODO(jpienaar): support other TensorFlow specific types. + llvm_unreachable("unsupported attribute type for building tf.Const"); +} + +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, + Attribute value) { + // Handle the case where the type and value are already tensors. + if (type.isa() && value.isa()) { + result.addTypes(type); + result.addAttribute("value", value); + return; + } + + // Otherwise, default to the attribute builder. 
+  ConstOp::build(builder, result, value);
+  assert(type == result.types[0] && "type mismatch in construction");
+}
+
+LogicalResult ConstOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto value = attributes.get("value");
+  if (!value) return emitOptionalError(location, "missing attribute 'value'");
+  if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
+    inferredReturnTypes.assign({elem_attr.getType()});
+    return success();
+  }
+  return emitOptionalError(location,
+                           "attribute 'value' failed to satisfy constraint: "
+                           "constant vector/tensor");
+}
+
+//===----------------------------------------------------------------------===//
+// Conv2DOp and Conv3DOp
+//===----------------------------------------------------------------------===//
+
+template <typename OpT>
+static LogicalResult VerifyConvOpAttributes(OpT op, int num_dims) {
+  if (!IsOfRankOrUnranked(op.getResult(), num_dims))
+    return op.emitOpError()
+           << "requires result to be " << num_dims << "D tensor";
+
+  auto is_not_positive = [](Attribute val) {
+    return val.cast<IntegerAttr>().getValue().getSExtValue() <= 0;
+  };
+
+  int64_t strides_size = op.strides().size();
+  if (strides_size != num_dims)
+    return op.emitOpError() << "requires strides attribute length to be "
+                            << num_dims << "; actual length " << strides_size;
+  if (llvm::any_of(op.strides().getValue(), is_not_positive))
+    return op.emitOpError("requires positive strides");
+
+  int64_t dilations_size = op.dilations().size();
+  if (dilations_size != num_dims)
+    return op.emitOpError() << "requires dilations attribute length to be "
+                            << num_dims << "; actual length " << dilations_size;
+  if (llvm::any_of(op.dilations().getValue(), is_not_positive))
+    return op.emitOpError("requires positive dilations");
+
+  return success();
+}
+
+// Verifies that,
+// * Ranks of operands and result are valid
+// * Number of input channels is divisible by the number of filter input
+//   channels
+// * Length of explicit_paddings attribute is valid and has non negative
+//   elements
+// * strides and dilations attributes have positive elements
+template <typename OpT, typename std::enable_if<llvm::is_one_of<
+                            OpT, Conv2DOp, Conv3DOp>::value>::type * = nullptr>
+static LogicalResult Verify(OpT op) {
+  int num_spatial_dims = std::is_same<OpT, Conv2DOp>() ? 2 : 3;
+  int num_dims = 2 + num_spatial_dims;
+
+  if (!IsOfRankOrUnranked(op.input(), num_dims) ||
+      !IsOfRankOrUnranked(op.filter(), num_dims))
+    return op.emitOpError()
+           << "requires operands to be " << num_dims << "D tensor";
+
+  // EXPLICIT padding mode and the associated attribute is limited to Conv2D.
+  // So, fetch attribute by string instead of the op.explicit_paddings()
+  // attribute getter.
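+  // The attribute holds two values (before/after padding) per dimension, in
+  // the same order as data_format, so for Conv2D it must have 2 * 4 = 8
+  // elements.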
+ if (op.padding() == "EXPLICIT") { + auto paddings = op.template getAttrOfType("explicit_paddings"); + if (!paddings) + return op.emitOpError() << "requires attribute 'explicit_paddings' with " + "'EXPLICIT' padding mode"; + + int64_t paddings_size = paddings.size(); + int64_t expected_size = 2 * num_dims; + + if (paddings_size != expected_size) + return op.emitOpError() + << "requires explicit_paddings attribute length to be " + << expected_size << "; actual length " << paddings_size; + + auto is_negative = [](Attribute val) { + return val.cast().getValue().getSExtValue() < 0; + }; + if (llvm::any_of(paddings.getValue(), is_negative)) + return op.emitOpError("requires non negative explicit paddings"); + } + + LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); + if (failed(verify_result)) { + return verify_result; + } + + int64_t input_channels = -1; + if (auto ty = op.input().getType().template dyn_cast()) { + absl::string_view data_format(op.data_format().data(), + op.data_format().size()); + tensorflow::TensorFormat format; + auto is_valid = FormatFromString(data_format, &format); + DCHECK(is_valid) << data_format; + int idx = tensorflow::GetTensorFeatureDimIndex(num_dims, format); + input_channels = ty.getDimSize(idx); + } + + int64_t filter_channels = -1; + if (auto ty = op.filter().getType().template dyn_cast()) { + int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( + num_dims, tensorflow::FORMAT_HWIO); + filter_channels = ty.getDimSize(idx); + } + + if (input_channels != -1 && filter_channels != -1 && + input_channels % filter_channels != 0) + return op.emitOpError() + << "requires the number of input channels to be divisible by the " + "number of filter input channels; found " + << input_channels << " and " << filter_channels << ", respectively"; + + return success(); +} + +LogicalResult Conv2DOp::UpdateDataFormat(StringRef data_format) { + auto perm = GetDataFormatPermutation(this->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + return success(); +} + +StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Input must be a tensor. + auto input_ty = input().getType().dyn_cast(); + if (!input_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = input_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For f32/f16 data type decision depends on the filter size in spatial + // dimensions, for other data types we keep current data format. + if (!input_ty.getElementType().isF32() && !input_ty.getElementType().isF16()) + return data_format(); + + // Keep current data format if filter rank is unknown or not equal to 4. 
+ auto filter_ty = filter().getType().dyn_cast(); + if (!filter_ty || filter_ty.getRank() != 4) return data_format(); + + const int64_t d0 = filter_ty.getDimSize(0); + const int64_t d1 = filter_ty.getDimSize(1); + + auto all_ones = [](ArrayAttr arr) -> bool { + return llvm::all_of(arr, [](Attribute attr) -> bool { + return attr.cast().getInt() == 1; + }); + }; + + // Convolutions with 1x1 filter and with strides and dilations all ones, can + // be computed as a GEMM in NHWC data format, and can be up to ~2x times + // faster than convolution in NCHW. + const bool one_by_one = d0 == 1 && d1 == 1; + const bool trivial_strides = all_ones(strides()); + const bool trivial_dilations = all_ones(dilations()); + + // TODO(ezhulenev): This might lead to excessive transposes in the final IR, + // if the ratio of 1x1 convolutions to regular convolutions is close to 1:1. + // Also FusedBatchNorm in training mode prefers NCHW data format. Check if all + // users can efficiently use NHWC data format? + if (one_by_one && trivial_strides && trivial_dilations) { + return "NHWC"; + } + + // If filter spatial dimensions are unknown or not 1x1 we prefer NCHW, because + // it's the fastest option on NVIDIA GPUs with cuDNN library support. + return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// Conv2dBackpropFilterOp +//===----------------------------------------------------------------------===// + +LogicalResult Conv2DBackpropFilterOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute filter sizes operand. + OpBuilder builder(getOperation()); + auto filter_sizes_permuted = builder.create( + getLoc(), filter_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(1, filter_sizes_permuted); + + return success(); +} + +StringRef Conv2DBackpropFilterOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Input must be a tensor. + auto input_ty = input().getType().dyn_cast(); + if (!input_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = input_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". 
+ return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// Conv2DBackpropInputOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(Conv2DBackpropInputOp op) { + int num_spatial_dims = 2; + int num_dims = 2 + num_spatial_dims; + + if (!IsOfRankOrUnranked(op.out_backprop(), num_dims) || + !IsOfRankOrUnranked(op.filter(), num_dims)) + return op.emitOpError() + << "requires operands to be " << num_dims << "D tensor"; + + LogicalResult verify_result = VerifyConvOpAttributes(op, num_dims); + if (failed(verify_result)) { + return verify_result; + } + + return success(); +} + +LogicalResult Conv2DBackpropInputOp::UpdateDataFormat(StringRef data_format) { + StringRef src_data_format = this->data_format(); + + auto perm = GetDataFormatPermutation(src_data_format, data_format); + if (perm.empty()) return failure(); + + // Update data_format attribute and result types. + if (failed(::mlir::TF::UpdateDataFormat(data_format, this))) return failure(); + + // Update convolution attributes. + setAttr("dilations", ShuffleArrayAttr(dilations(), perm)); + setAttr("strides", ShuffleArrayAttr(strides(), perm)); + setAttr("explicit_paddings", ShuffleArrayAttr(explicit_paddings(), perm, 2)); + + // Permute input sizes operand. + OpBuilder builder(getOperation()); + auto input_sizes_permuted = builder.create( + getLoc(), input_sizes(), StringAttr::get(src_data_format, getContext()), + StringAttr::get(data_format, getContext())); + setOperand(0, input_sizes_permuted); + + return success(); +} + +StringRef Conv2DBackpropInputOp::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // Filter must be a tensor. + auto filter_ty = filter().getType().dyn_cast(); + if (!filter_ty) return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + const bool is_f16 = filter_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // Otherwise always use "NCHW". 
+ return "NCHW"; +} + +//===----------------------------------------------------------------------===// +// DataFormatVecPermuteOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DataFormatVecPermuteOp op) { + auto input_ty = op.x().getType().dyn_cast(); + if (!input_ty) return success(); + + int rank = input_ty.getRank(); + if (rank != 1 && rank != 2) + return op.emitOpError("requires input of rank 1 or 2"); + + if (rank == 1) { + int64_t dim0 = input_ty.getDimSize(0); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4 && dim0 != 2) + return op.emitOpError("requires 1D input of size 4 or size 2"); + } + + if (rank == 2) { + int64_t dim0 = input_ty.getDimSize(0); + if (dim0 != ShapedType::kDynamicSize && dim0 != 4) + return op.emitOpError( + "requires first dimensions of 2D input to be of size 4"); + + int64_t dim1 = input_ty.getDimSize(1); + if (dim1 != ShapedType::kDynamicSize && dim1 != 2) + return op.emitOpError( + "requires second dimensions of 2D input to be of size 2"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// DivOp +//===----------------------------------------------------------------------===// + +void DivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult DivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// DynamicStitchOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicStitchOp op) { + if (op.N() < 1) return op.emitOpError("requires attribute N with value >= 1"); + + if (RankedTensorType out_ty = op.getType().dyn_cast()) { + if (out_ty.getRank() == 0) { + return op.emitOpError("requires non scalar output"); + } + } + + llvm::SmallDenseSet index_values; + bool all_indices_const = true; + int32_t max_index = -1; + llvm::Optional> inferred_item_shape; + for (auto it : llvm::zip(op.indices(), op.data())) { + Value index = std::get<0>(it); + + DenseIntElementsAttr index_attr; + if (matchPattern(index, m_Constant(&index_attr))) { + for (int32_t index : index_attr.getValues()) { + if (index < 0) + return op.emitOpError() + << "requires non-negative index values; found " << index; + max_index = std::max(index, max_index); + index_values.insert(index); + } + } else { + all_indices_const = false; + } + + Value data = std::get<1>(it); + RankedTensorType index_ty = index.getType().dyn_cast(); + RankedTensorType data_ty = data.getType().dyn_cast(); + if (!index_ty || !data_ty) continue; + + int64_t index_rank = index_ty.getRank(); + ArrayRef data_shape = data_ty.getShape(); + ArrayRef index_shape = index_ty.getShape(); + if (failed(mlir::verifyCompatibleShape(index_shape, + data_shape.take_front(index_rank)))) + return op.emitOpError() << "requires shape of data with type " << data_ty + << " to have prefix matching with shape of the " + "corresponding index type " + << index_ty; + + ArrayRef item_shape = data_shape.drop_front(index_rank); + if (!inferred_item_shape) { + inferred_item_shape = llvm::to_vector<4>(item_shape); + continue; + } + + if (failed(mlir::verifyCompatibleShape(item_shape, *inferred_item_shape))) + return op.emitOpError() << "has inconsistent shaped data and index " + "pairs; inferred item shapes [" + << llvm::makeArrayRef(*inferred_item_shape) + << "] 
and [" << item_shape << "] don't match"; + for (int i = 0, e = item_shape.size(); i < e; ++i) { + int64_t &inferred_dim = (*inferred_item_shape)[i]; + int64_t dim = item_shape[i]; + if (ShapedType::isDynamic(inferred_dim)) inferred_dim = dim; + } + } + + // If all indices are constants, then verify that they cover all indices in + // the range [0, max_index] and the output type is legal. + if (all_indices_const) { + for (int32_t i = 0; i <= max_index; i++) { + if (!index_values.count(i)) + return op.emitOpError() << "missing index " << i; + } + + if (inferred_item_shape) { + SmallVector expected_shape; + expected_shape.push_back(max_index + 1); + expected_shape.append(inferred_item_shape->begin(), + inferred_item_shape->end()); + + auto out_ty = op.getType().cast(); + auto expected_out_ty = + RankedTensorType::get(expected_shape, out_ty.getElementType()); + + if (!AreCastCompatible({out_ty, expected_out_ty})) { + return op.emitOpError() << "has invalid output type; should be " + "compatible with inferred type " + << expected_out_ty; + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// EinsumOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// * Arity of the op is at most two. +// +// TODO(hinsu): Verify einsum equation attribute. +static LogicalResult Verify(EinsumOp op) { + if (op.N() > 2) { + return op.emitOpError("supports at most two operands"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// EmptyOp +//===----------------------------------------------------------------------===// + +OpFoldResult EmptyOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "empty op has one operand"); + + Attribute attr = operands.front(); + if (!attr) return {}; + + auto int_attr = attr.cast(); + SmallVector out_shape; + for (const auto val : int_attr.getValues()) { + out_shape.push_back(val); + } + + auto type = getResult().getType().cast(); + auto etype = type.getElementType(); + + // We can not fold if the result is not static. + if (!type.hasStaticShape()) return {}; + + if (auto float_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, float_type); + return DenseElementsAttr::get(out_type, + {APFloat(float_type.getFloatSemantics())}); + } + + if (auto int_type = etype.dyn_cast()) { + auto out_type = RankedTensorType::get(out_shape, etype); + APInt val(int_type.getWidth(), 0, int_type.getSignedness()); + return DenseElementsAttr::get(out_type, val); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// EmptyTensorListOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EmptyTensorListOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + + if (!IsOfRankOrUnranked(op.max_num_elements(), 0)) { + return op.emitOpError("requires max_num_elements operand to be 0D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// EqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(EqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. 
+ if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. + return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void EqualOp::build(OpBuilder &builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + +//===----------------------------------------------------------------------===// +// ExpandDimsOp +//===----------------------------------------------------------------------===// + +Type InferExpandDimsOpType(Value input, Value dim) { + Type element_ty = input.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + + auto input_ty = input.getType().dyn_cast(); + if (!input_ty) return unranked_ty; + + DenseIntElementsAttr dim_attr; + if (!matchPattern(dim, m_Constant(&dim_attr)) || + dim_attr.getNumElements() != 1) + return unranked_ty; + int64_t dim_val = (*dim_attr.begin()).getSExtValue(); + int64_t input_rank = input_ty.getRank(); + + if (dim_val < -input_rank - 1 || dim_val > input_rank + 1) return unranked_ty; + if (dim_val < 0) dim_val += input_rank + 1; + + SmallVector shape = llvm::to_vector<4>(input_ty.getShape()); + shape.insert(shape.begin() + dim_val, 1); + return RankedTensorType::get(shape, element_ty); +} + +void ExpandDimsOp::build(OpBuilder &builder, OperationState &result, + Value input, Value dim) { + return build(builder, result, InferExpandDimsOpType(input, dim), input, dim); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxArgsOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxArgsOp op) { + // TODO(fengliuai): moving the following to an utility method. + const llvm::fltSemantics &semantics = op.min().getSemantics(); + float rmin, rmax; + if (&semantics == &APFloat::IEEEsingle()) { + rmin = op.min().convertToFloat(); + rmax = op.max().convertToFloat(); + } else { + rmin = op.min().convertToDouble(); + rmax = op.max().convertToDouble(); + } + // Range boundaries must be valid. 
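+  // For example, [min, max] = [-6.0, 6.0] is accepted, while [1.0, 1.0] and
+  // [3.0, -3.0] are rejected.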
+ if (rmin >= rmax) { + return op.emitOpError("range is invalid: [" + Twine(std::to_string(rmin)) + + "," + Twine(std::to_string(rmax)) + "]"); + } + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxVarsOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxVarsOp op) { + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 0)) + return op.emitOpError("requires min to be a 0d float tensor"); + + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 0)) + return op.emitOpError("requires max to be a 0d float tensor"); + + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// FakeQuantWithMinMaxVarsPerChannelOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(FakeQuantWithMinMaxVarsPerChannelOp op) { + auto min = GetRankedTensorTypeForOperand(op.min()); + if (min && !IsOfRankedFloatTensorType(min, 1)) + return op.emitOpError("requires min to be a 1d float tensor"); + + auto max = GetRankedTensorTypeForOperand(op.max()); + if (max && !IsOfRankedFloatTensorType(max, 1)) + return op.emitOpError("requires max to be a 1d float tensor"); + + Value inputs = op.inputs(); + if (!HasRankAtLeast(inputs, 1)) + return op.emitError("requires inputs to be at least 1d float tensor"); + + int64_t num_bits = op.num_bits().getSExtValue(); + if (num_bits < 2 || num_bits > 16) { + return op.emitOpError( + "requires num_bits to be between 2 and 16, inclusive"); + } + + auto inputs_type = inputs.getType().dyn_cast(); + if (!inputs_type) return success(); + int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); + if ((min && min.getDimSize(0) != depth) || + (max && max.getDimSize(0) != depth)) { + return op.emitOpError( + "requires min and max to have same size as last dimension of inputs"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// FillOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(FillOp op) { + if (!IsOfRankOrUnranked(op.dims(), 1)) + return op.emitOpError() << "requires dims to be a 1D tensor"; + if (!IsOfRankOrUnranked(op.value(), 0)) + return op.emitOpError() << "requires value to be a scalar"; + + return success(); +} + +static ShapedType InferFillOpType(Value dims, Value value) { + Type etype = value.getType().cast().getElementType(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) { + return UnrankedTensorType::get(etype); + } + + llvm::SmallVector shape; + shape.reserve(dims_attr.getNumElements()); + for (const APInt dim : dims_attr.getValues()) { + shape.push_back(dim.getSExtValue()); + } + return RankedTensorType::get(shape, etype); +} + +void FillOp::build(OpBuilder &builder, OperationState &result, Value dims, + Value value) { + FillOp::build(builder, result, InferFillOpType(dims, value), dims, value); +} + 
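+// Folds tf.Fill with constant operands. For example, dims = [2, 3] with
+// value = 9.0 folds to a splat constant of type tensor<2x3xf32> in which
+// every element is 9.0.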
+OpFoldResult FillOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "fill op has two operand"); + + auto type = getType().cast(); + // DenseElementsAttr that is used in this folder only supports int and float + // types. + // TODO(hinsu): Handle complex types once there is a attribute kind for + // complex. + if (!type.getElementType().isIntOrFloat()) return {}; + + auto value = operands[1].dyn_cast_or_null(); + if (!value) return {}; + + if (type.hasStaticShape()) + return DenseElementsAttr::get(type, value.getValue({})); + + auto dims = operands[0].dyn_cast_or_null(); + if (!dims) return {}; + + llvm::SmallVector shape; + shape.reserve(dims.getNumElements()); + for (const APInt dim : dims.getValues()) { + shape.push_back(dim.getSExtValue()); + } + type = RankedTensorType::get(shape, type.getElementType()); + + return DenseElementsAttr::get(type, value.getValue({})); +} + +//===----------------------------------------------------------------------===// +// FusedBatchNormGradOp +//===----------------------------------------------------------------------===// + +// TODO(b/150954845): Add benchmarks to verify that layout preference didn't +// change in the latest GPU generations. + +LogicalResult FusedBatchNormGradV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormGradV3Op::GetOptimalLayout( + const RuntimeDevices &devices) { + // Keep current data format if no GPUs are available or if explicit placement + // does not allow to use GPU for this operation. + if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(getOperation())) + return data_format(); + + // For f16 data type on devices with Tensor Cores support NHWC data format + // is up to ~2x faster. + auto x_ty = x().getType().cast(); + const bool is_f16 = x_ty.getElementType().isF16(); + if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; + + // For all other data types prefer NCHW. 
+  return "NCHW";
+}
+
+//===----------------------------------------------------------------------===//
+// FusedBatchNormOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult Verify(FusedBatchNormOp op) {
+  auto x = GetRankedTensorTypeForOperand(op.x());
+  if (x && !IsOfRankedFloatTensorType(x, 4))
+    return op.emitOpError("requires x to be a 4D float tensor");
+
+  auto scale = GetRankedTensorTypeForOperand(op.scale());
+  if (scale && !IsOfRankedFloatTensorType(scale, 1))
+    return op.emitOpError("requires scale to be a 1D float tensor");
+
+  auto offset = GetRankedTensorTypeForOperand(op.offset());
+  if (offset && !IsOfRankedFloatTensorType(offset, 1))
+    return op.emitOpError("requires offset to be a 1D float tensor");
+
+  auto mean = GetRankedTensorTypeForOperand(op.mean());
+  if (mean && !IsOfRankedFloatTensorType(mean, 1))
+    return op.emitOpError("requires mean to be a 1D float tensor");
+
+  auto variance = GetRankedTensorTypeForOperand(op.variance());
+  if (variance && !IsOfRankedFloatTensorType(variance, 1))
+    return op.emitOpError("requires variance to be a 1D float tensor");
+
+  // TODO(antiagainst): check attributes
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// FusedBatchNormV2Op / FusedBatchNormV3Op
+//===----------------------------------------------------------------------===//
+
+template <class Op>
+static LogicalResult InferenceFoldOperandsPermutation(
+    ArrayRef<int64_t> permutation, Op *op) {
+  // FusedBatchNorm in training mode is a layout sensitive operation, and
+  // should have already been assigned an optimal data format.
+  if (op->is_training()) return failure();
+  return ::mlir::TF::FoldOperandsPermutation(permutation, op);
+}
+
+template <class Op>
+static StringRef GetOptimalLayout(const RuntimeDevices &devices, Op *op) {
+  // In inference mode FusedBatchNorm is not sensitive to data layout.
+  if (!op->is_training()) return op->data_format();
+
+  // Keep current data format if no GPUs are available or if explicit placement
+  // does not allow to use GPU for this operation.
+  if (!CanUseGpuDevice(devices) || !CanUseGpuDevice(op->getOperation()))
+    return op->data_format();
+
+  // For f16 data type on devices with Tensor Cores support NHWC data format
+  // is up to ~2x faster.
+  auto x_ty = op->x().getType().template cast<TensorType>();
+  const bool is_f16 = x_ty.getElementType().isF16();
+  if (is_f16 && CanUseTensorCores(devices)) return "NHWC";
+
+  // For all other data types prefer NCHW.
+ return "NCHW"; +} + +LogicalResult FusedBatchNormV2Op::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); +} + +LogicalResult FusedBatchNormV2Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormV2Op::GetOptimalLayout(const RuntimeDevices &devices) { + return ::mlir::TF::GetOptimalLayout(devices, this); +} + +LogicalResult FusedBatchNormV3Op::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::InferenceFoldOperandsPermutation(permutation, this); +} + +LogicalResult FusedBatchNormV3Op::UpdateDataFormat(StringRef data_format) { + return ::mlir::TF::UpdateDataFormat(data_format, this); +} + +StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices &devices) { + return ::mlir::TF::GetOptimalLayout(devices, this); +} + +//===----------------------------------------------------------------------===// +// GatherV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(GatherV2Op op) { + int64_t batch_dims = op.batch_dims().getSExtValue(); + if (auto ty = op.indices().getType().dyn_cast()) { + int64_t rank = ty.getRank(); + if (batch_dims > rank || batch_dims < -rank) + return op.emitOpError() + << "batch_dims (" << batch_dims << ") must be in range [" << -rank + << ", " << rank + 1 << ")"; + if (batch_dims < 0) batch_dims += rank; + } + + if (!HasRankAtMost(op.axis(), 1)) + return op.emitOpError("requires axis to have rank at most 1"); + + DenseIntElementsAttr axis_attr; + if (matchPattern(op.axis(), m_Constant(&axis_attr))) { + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (auto ty = op.params().getType().dyn_cast()) { + int64_t rank = ty.getRank(); + if (axis >= rank || axis < -rank) + return op.emitOpError() << "axis (" << axis << ") must be in range [" + << -rank << ", " << rank << ")"; + if (axis < 0) axis += rank; + } + + if (batch_dims >= 0 && axis >= 0 && axis < batch_dims) { + return op.emitOpError() << "requires axis (" << axis + << ") to be greater than or equal to batch_dims (" + << batch_dims << ")"; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// IfOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(IfOp op) { + auto then_fn = op.then_func(); + if (!then_fn) + return op.emitOpError("then_branch refers to an undefined function : ") + << op.then_branch(); + auto else_fn = op.else_func(); + if (!else_fn) + return op.emitOpError("else_branch refers to an undefined function : ") + << op.else_branch(); + auto then_fn_type = then_fn.getType(); + auto else_fn_type = else_fn.getType(); + + // Non-conditional operands starting with the second operand are passed to + // branches and should be pair-wise compatible with branches' inputs. 
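+  // For example, "tf.If"(%cond, %a, %b) calls the branch functions with
+  // (%a, %b), so both then_branch and else_branch must accept exactly two
+  // arguments with compatible types.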
+ unsigned expected_num_inputs = op.getNumOperands() - 1; + if (then_fn_type.getNumInputs() != expected_num_inputs || + else_fn_type.getNumInputs() != expected_num_inputs) + return op.emitError("branches should have " + Twine(expected_num_inputs) + + " inputs"); + + for (unsigned i = 0; i < expected_num_inputs; ++i) { + auto operand_type = op.getOperand(i + 1).getType().cast(); + auto then_input_type = then_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, then_input_type})) + return op.emitError( + llvm::formatv("then branch input type {0} is incompatible with " + "operand type {1} at index {2}", + then_input_type, operand_type, i)); + + auto else_input_type = else_fn_type.getInput(i).cast(); + if (!AreCastCompatible({operand_type, else_input_type})) + return op.emitError( + llvm::formatv("else branch input type {0} is incompatible with " + "operand type {1} at index {2}", + else_input_type, operand_type, i)); + + // If branches have incompatible input types that means that no tensor can + // serve as input to both the functions. Hence, the op is invalid. + if (!AreCastCompatible({then_input_type, else_input_type})) + return op.emitError(llvm::formatv( + "branches inputs have incompatible types {0} and {1} at index {2}", + then_input_type, else_input_type, i)); + } + + // Branches' results should be pair-wise compatible with the op results. + unsigned expected_num_results = op.getNumResults(); + if (then_fn_type.getNumResults() != expected_num_results || + else_fn_type.getNumResults() != expected_num_results) + return op.emitError("branches should have " + Twine(expected_num_results) + + " results"); + + for (unsigned i = 0; i < expected_num_results; ++i) { + auto result_type = op.getResult(i).getType().cast(); + auto then_result_type = then_fn_type.getResult(i).cast(); + if (!AreCastCompatible({then_result_type, result_type})) + return op.emitError( + llvm::formatv("then branch result type {0} is incompatible with op " + "result type {1} at index {2}", + then_result_type, result_type, i)); + + auto else_result_type = else_fn_type.getResult(i).cast(); + if (!AreCastCompatible({else_result_type, result_type})) + return op.emitError( + llvm::formatv("else branch result type {0} is incompatible with op " + "result type {1} at index {2}", + else_result_type, result_type, i)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// IfOp canonicalization. +//===----------------------------------------------------------------------===// + +class FoldConstantIfOp : public OpRewritePattern { + public: + explicit FoldConstantIfOp(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::IfOp op, + PatternRewriter &rewriter) const override; + + private: + template + struct CallOpType { + using CallOp = T; + }; +}; + +LogicalResult FoldConstantIfOp::matchAndRewrite( + TF::IfOp op, PatternRewriter &rewriter) const { + // Extract the constant cond value. + DenseIntElementsAttr cond_attr; + if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); + + // Cond value must be a scalar. + if (cond_attr.getNumElements() != 1) return failure(); + + // Select a branch function. + bool cond = cond_attr.getSplatValue().getValue(); + FlatSymbolRefAttr func = cond ? op.then_branchAttr() : op.else_branchAttr(); + + // Replace IfOp with PartitionedCallOp or StatefulPartitionedCallOp. 
+ auto rewrite = [&](auto op_type) { + auto empty = rewriter.getStringAttr(""); + auto call_op = rewriter.create( + op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, + /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); + PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + rewriter.replaceOp(op, call_op.getResults()); + }; + + if (op.is_stateless()) + rewrite(CallOpType{}); + else + rewrite(CallOpType{}); + + return success(); +} + +void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); +} + +//===----------------------------------------------------------------------===// +// IfRegionOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(IfRegionOp op) { + if (failed(VerifyRegionResults(op, op.then_branch(), "then"))) + return failure(); + if (failed(VerifyRegionResults(op, op.else_branch(), "else"))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// InvertOp +//===----------------------------------------------------------------------===// + +void InvertOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// InvertPermutationOp +//===----------------------------------------------------------------------===// + +// Verifies that the input is 1D. +static LogicalResult Verify(InvertPermutationOp op) { + auto x_type = op.x().getType().cast(); + if (!x_type.hasRank()) return success(); + if (x_type.getShape().size() != 1) + return op.emitOpError() << "requires input x to be 1-dimensional"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// LeakyReluOp +//===----------------------------------------------------------------------===// + +OpFoldResult LeakyReluOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "leaky relu has one operand"); + + // leaky_relu(x, alpha: 1) -> x + if (alpha().convertToFloat() == 1.0f) return getOperand(); + + auto calculate = [&](FloatAttr arg) { + APFloat val = arg.getValue(); + if (val.isNegative()) val = alpha() * val; + return FloatAttr::get(arg.getType(), val); + }; + + if (auto arg = operands[0].dyn_cast_or_null()) { + return calculate(arg); + } else if (auto arg = operands[0].dyn_cast_or_null()) { + if (auto elementAttr = arg.getSplatValue().dyn_cast()) + return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// LogOp +//===----------------------------------------------------------------------===// + +void LogOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// LogicalNotOp +//===----------------------------------------------------------------------===// + +void LogicalNotOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// MatrixBandPartOp +//===----------------------------------------------------------------------===// + +static LogicalResult 
Verify(MatrixBandPartOp op) { + if (!HasRankAtLeast(op.input(), 2)) { + return op.emitOpError() + << "requires `input` to have rank of at least 2, but found " + << op.input().getType(); + } + if (!IsOfRankOrUnranked(op.num_lower(), 0)) { + return op.emitOpError() + << "requires `num_lower` to have 0 dimensions, but found " + << op.num_lower().getType(); + } + if (!IsOfRankOrUnranked(op.num_upper(), 0)) { + return op.emitOpError() + << "requires `num_upper` to have 0 dimensions, but found " + << op.num_upper().getType(); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// MaxOp +//===----------------------------------------------------------------------===// + +void MaxOp::build(OpBuilder &builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { + Type out_ty = + InferReductionOpType(input, reduction_indices, keep_dims, &builder); + build(builder, result, out_ty, input, reduction_indices, keep_dims); +} + +//===----------------------------------------------------------------------===// +// MaxPoolOp +//===----------------------------------------------------------------------===// + +LogicalResult MaxPoolOp::FoldOperandsPermutation( + ArrayRef permutation) { + return ::mlir::TF::FoldOperandsPermutation( + permutation, this, {{"strides", strides()}, {"ksize", ksize()}}); +} + +//===----------------------------------------------------------------------===// +// MaxPoolGradOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(MaxPoolGradOp op) { + if (!IsOfRankOrUnranked(op.orig_input(), 4)) { + return op.emitOpError() << "requires orig_input to be rank 4"; + } + if (!IsOfRankOrUnranked(op.orig_output(), 4)) { + return op.emitOpError() << "requires orig_output to be rank 4"; + } + if (!IsOfRankOrUnranked(op.grad(), 4)) { + return op.emitOpError() << "requires grad to be rank 4"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// MeanOp +//===----------------------------------------------------------------------===// + +LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { + // Reduction indices must be defined by a constant operation. + auto reduction_op = + dyn_cast_or_null(reduction_indices().getDefiningOp()); + if (!reduction_op) return failure(); + + auto reductions_value = reduction_op.value().dyn_cast(); + if (!reductions_value) return failure(); + + // Prepare new reduction indices according to operand permutation. + SmallVector shuffled_reduction; + llvm::transform(reductions_value.getIntValues(), + std::back_inserter(shuffled_reduction), + [&](APInt idx) { return permutation[idx.getSExtValue()]; }); + + // Add constant operation with a new reduction indices. + OpBuilder builder(getOperation()); + auto type = mlir::RankedTensorType::get(shuffled_reduction.size(), + builder.getIntegerType(32)); + auto values = mlir::DenseIntElementsAttr::get(type, shuffled_reduction); + auto shuffled_reduction_op = builder.create(getLoc(), values); + + // Use new reduction indices. 
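Viewed in isolation, the reduction-index remapping above is a single gather through the permutation. A minimal sketch, assuming `permutation[old_axis]` yields the axis a dimension moves to; `ShuffleReductionIndices` is an illustrative name and the example values are hypothetical.

#include <cstdint>
#include <vector>

// `permutation[old_axis]` is assumed to be the axis the dimension occupies
// after the folded operand transpose.
std::vector<int64_t> ShuffleReductionIndices(
    const std::vector<int64_t> &reduction_indices,
    const std::vector<int64_t> &permutation) {
  std::vector<int64_t> shuffled;
  shuffled.reserve(reduction_indices.size());
  for (int64_t axis : reduction_indices)
    shuffled.push_back(permutation[axis]);
  return shuffled;
}

// Example: a Mean over the spatial axes {1, 2} of an NHWC tensor becomes a
// Mean over {2, 3} once the data is viewed as NCHW, since {0, 2, 3, 1} maps
// H: 1 -> 2, W: 2 -> 3, C: 3 -> 1.
// auto nchw_axes = ShuffleReductionIndices({1, 2}, {0, 2, 3, 1});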
+ setOperand(1, shuffled_reduction_op); + + return success(); +} + +//===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +OpFoldResult MulOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc.inc" + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h new file mode 100644 index 00000000000..19a927a23d7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +namespace mlir { +namespace TF { + +class YieldOp; + +// TODO(b/131258166): TensorFlow's mutex.h defines a `mutex_lock` macro, whose +// purpose is to catch bug on `tensorflow::mutex_lock`. We don't use +// `tensorflow::mutex_lock` here but we have ops (`tf.MutexLock` and +// `tf.ConsumeMutexLock`) with getter methods named as `mutex_lock()`. Need to +// undefine here to avoid expanding the getter symbol as macro when including +// both mutex.h and this header file. 
+#undef mutex_lock + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h.inc" + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_A_M_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc new file mode 100644 index 00000000000..71f1560aa6c --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -0,0 +1,600 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a simple include file used to simplify the splitting of the +// tf_ops.cc file. The helpers in here should be refactored and moved to +// tf_verifiers or tf_ops. +// TODO(jpienaar): Remove this file post refactoring. + +// Propagates underscore and device attributes from src to dst. +// TODO(b/158769932): This should be a general feature instead post some policy +// discussion. +static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) { + auto device = mlir::Identifier::get("device", src->getContext()); + for (auto named_attr : src->getAttrs()) { + if (*named_attr.first.begin() == '_' || named_attr.first == device) + dst->setAttr(named_attr.first, named_attr.second); + } +} + +//===----------------------------------------------------------------------===// +// TF op helper functions +//===----------------------------------------------------------------------===// + +// Returns the RankedTensorType for the given operand. TensorFlow constant ops +// may have non-static shape because the shape is not propagated during constant +// folding. If the defining op for the given operand is a constant op, this +// routine uses the constant op's attribute to get the actual shape. +static RankedTensorType GetRankedTensorTypeForOperand(Value operand) { + DenseElementsAttr attr; + if (matchPattern(operand, m_Constant(&attr))) { + return attr.getType().dyn_cast(); + } + return operand.getType().dyn_cast(); +} + +// Returns true if the given `value` is of ranked float tensor type with the +// given `rank`. +static inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { + return type && type.getRank() == rank && + type.getElementType().isa(); +} + +// Returns true if the given `value` has the specified rank or has unranked +// type. +static inline bool IsOfRankOrUnranked(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() == rank; +} + +// Returns true if the given `value` has at least the specified rank or has +// unranked type. 
+static inline bool HasRankAtLeast(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() >= rank; +} + +// Returns true if the given `value` has at most the specified rank or has +// unranked type. +static inline bool HasRankAtMost(Value value, int64_t rank) { + RankedTensorType type = GetRankedTensorTypeForOperand(value); + return !type || type.getRank() <= rank; +} + +static bool IsUnknownDimOrRank(int64_t dim_or_rank) { + return dim_or_rank == -1; +} + +// Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If +// `incompatible_shape_error` is true, reports error if `x` and `y` has +// incompatible shapes. Otherwise, returns a tensor type with unknown rank. +static Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = + OpTrait::util::getBroadcastedType(x.getType(), y.getType()); + if (!result_type) { + if (incompatible_shape_error.getValue()) { + mlir::emitError(loc, "non-broadcastable operands"); + } else { + return UnrankedTensorType::get(builder->getI1Type()); + } + } + + auto ranked_type = result_type.dyn_cast(); + if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); + + return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); +} + +// Returns dimension index for the given TensorFlow axis that supports negative +// indexing. +static int64_t GetDimForAxis(int64_t axis, int64_t rank) { + return axis >= 0 ? axis : axis + rank; +} + +// Infers output type for reduction ops such as SumOp, MaxOp etc. +// TODO(b/e667204a): Move this logic to shape inference once it supports custom +// inference functions. +static Type InferReductionOpType(Value input, Value reduction_indices, + BoolAttr keep_dims, Builder *builder) { + Type input_ty = input.getType(); + Type element_ty = getElementTypeOrSelf(input_ty); + + // Output type is unranked if input type is not ranked. + auto ranked_ty = input_ty.dyn_cast(); + if (!ranked_ty) return UnrankedTensorType::get(element_ty); + int64_t rank = ranked_ty.getRank(); + + DenseIntElementsAttr indices; + if (!matchPattern(reduction_indices, m_Constant(&indices))) { + // Output type is unranked if reduction indices are not constant and reduced + // dimensions are not kept. + if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); + + // Otherwise, output type has same rank as the input. + return RankedTensorType::get(SmallVector(rank, -1), element_ty); + } + + int64_t num_reduce_dim = 0; + llvm::SmallVector is_reduce_dim(rank, false); + for (const APInt &index : indices.getValues()) { + int64_t dim = GetDimForAxis(index.getSExtValue(), rank); + // Invalid input. + if (dim < 0 || dim >= rank) return UnrankedTensorType::get(element_ty); + + if (!is_reduce_dim[dim]) { + is_reduce_dim[dim] = true; + num_reduce_dim++; + } + } + + ArrayRef shape = ranked_ty.getShape(); + SmallVector out_shape; + out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); + for (int64_t i = 0; i < rank; ++i) { + if (!is_reduce_dim[i]) + out_shape.push_back(shape[i]); + else if (keep_dims.getValue()) + out_shape.push_back(1); + } + return RankedTensorType::get(out_shape, element_ty); +} + +// Verifies that the given types are cast compatible. If not, emits appropriate +// error for the given op. If mask_one_dim is set to true, then the types are +// allowed to have one mismatching dimension. 
Masking one of the dimensions is +// useful for ops like Concat that requires all ranked inputs to have the same +// rank and match dimension sizes for all but one of the dimensions. +static LogicalResult VerifyTypesCompatibility( + Operation::operand_type_range types, bool mask_one_dim, Operation *op) { + constexpr int64_t kUninitialized = -1; + int64_t common_rank = kUninitialized; + llvm::SmallVector common_dims; + int64_t dim_to_mask = kUninitialized; + + // Initialize common_rank with rank of the first ranked type and verify that + // following ranked types have the same rank. + // Similarly, initialize each of the dimensions with the first type that has + // the dimension size available and verify that all following types have the + // same size for the dimension. However, if mask_one_dim is true, note down + // the dimension index on the first mismatch and ignore dimension at that + // index in following types. + for (Type ty : types) { + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) continue; + + int64_t rank = ranked_ty.getRank(); + if (common_rank == kUninitialized) { + common_rank = rank; + common_dims.resize(common_rank, kUninitialized); + } else if (common_rank != rank) { + return op->emitError() + << "operand type " << ranked_ty + << " is not compatible with preceding operands; expected rank: " + << common_rank; + } + + for (int64_t i = 0, e = common_rank; i != e; i++) { + if (i == dim_to_mask) continue; + + int64_t dim = ranked_ty.getDimSize(i); + if (dim == kUninitialized) continue; + + int64_t &common_dim = common_dims[i]; + if (common_dim == kUninitialized) { + common_dim = dim; + } else if (common_dim != dim) { + // If mask_one_dim is true, do not emit an error if this is the only + // dimension with mismatches. Note down the dimension to mask it from + // the following types. + if (mask_one_dim && dim_to_mask == kUninitialized) { + dim_to_mask = i; + continue; + } + + return op->emitError() << "operand type " << ranked_ty + << " is not compatible with preceding operands; " + "expected dimension at index " + << i << ": " << common_dim; + } + } + } + return success(); +} + +// This is a helper for the Select to SelectV2 canonicalization. The `data` rank +// refers to the rank of `t`/`e` (these two inputs have equal rank; this is +// checked in the verifier). +// +// In most cases, the predicate for Select can be used directly as the predicate +// for SelectV2. However, there is one case that varies, which is when the +// predicate is a tensor and the data is multidimensional. In this case, Select +// op semantics dictate that the predicate tensor length must match the size of +// the first data dimension. This varies from normal broadcasting semantics +// (which are used in SelectV2), so we must reshape the tensor in this case to +// be compatible. +static Value ReshapeSelectPredIfNecessary(OpBuilder *builder, Location loc, + Value cond, int data_rank) { + auto cond_tensor = cond.getType().cast(); + // Reshape is only needed in the case that the cond rank is 1 (i.e. it is + // a vector) AND t/e rank is > 1. + if (cond_tensor.getRank() != 1 || data_rank <= 1) { + // No reshape necessary. Leave cond as it is. + return cond; + } + + // This is the case where a reshape is needed. We want to construct the + // shape [x,1,...1], where x is the value in the pred tensor and the + // length of the shape is equal to data_rank. 
+ SmallVector shape(data_rank, 1); + shape[0] = cond_tensor.getShape().front(); + auto new_shape_type = + RankedTensorType::get({data_rank}, builder->getIntegerType(64)); + auto shape_attr = DenseIntElementsAttr::get(new_shape_type, shape); + auto new_shape = builder->create(loc, shape_attr); + return builder->create(loc, cond, new_shape); +} + +//===----------------------------------------------------------------------===// +// Helper functions detect device capabilities from RuntimeDevices. +//===----------------------------------------------------------------------===// + +namespace { +using DeviceNameUtils = ::tensorflow::DeviceNameUtils; +using ParsedName = ::tensorflow::DeviceNameUtils::ParsedName; + +bool IsGpuDevice(const DeviceNameUtils::ParsedName &device) { + return device.type == ::tensorflow::DEVICE_GPU; +} + +} // namespace + +// Returns true if at least one GPU device is available at runtime. +bool CanUseGpuDevice(const RuntimeDevices &devices) { + return llvm::any_of(devices.device_names(), IsGpuDevice); +} + +// Returns true if all of the GPUs available at runtime support TensorCores +// (NVIDIA compute capability >= 7.0). +bool CanUseTensorCores(const RuntimeDevices &devices) { + auto has_tensor_cores = [&](const DeviceNameUtils::ParsedName &device) { + auto md = devices.GetGpuDeviceMetadata(device); + return md ? md->cc_major().getInt() >= 7 : false; + }; + return llvm::all_of( + llvm::make_filter_range(devices.device_names(), IsGpuDevice), + has_tensor_cores); +} + +// Returns true if operation does not have explicit device placement that would +// prevent it from running on GPU device. +bool CanUseGpuDevice(Operation *op) { + auto device_attr = op->getAttrOfType("device"); + if (!device_attr || device_attr.getValue().empty()) return true; + + DeviceNameUtils::ParsedName device; + if (!DeviceNameUtils::ParseFullName(device_attr.getValue().str(), &device)) + return false; + + // We can't use GPU if operation explicitly placed on non-GPU device. + return !device.has_type || device.type == ::tensorflow::DEVICE_GPU; +} + +//===----------------------------------------------------------------------===// +// TF op helper functions to work with layout transformation. +//===----------------------------------------------------------------------===// + +SmallVector ReversePermutation(ArrayRef permutation) { + SmallVector reverse(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) { + reverse[permutation[i]] = i; + } + return reverse; +} + +SmallVector GetDataFormatPermutation(StringRef from, StringRef to) { + if (from == "NHWC" && to == "NCHW") { + return {0, 3, 1, 2}; + } else if (from == "NCHW" && to == "NHWC") { + return {0, 2, 3, 1}; + } else { + return {}; + } +} + +// Shuffle elements in the `attr` according to the permutation. Optional +// `inner_size` allows to shuffle array attributes created from rank 2 tensors +// on outer dimension only. +ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, + int inner_size = 1) { + if (attr.size() == 0) return attr; + + assert(attr.size() % inner_size == 0); + assert(attr.size() / inner_size == permutation.size()); + + SmallVector values{attr.begin(), attr.end()}; + SmallVector shuffled(values.size()); + + for (size_t i = 0; i < permutation.size(); ++i) { + for (size_t j = 0; j < inner_size; ++j) { + shuffled[i * inner_size + j] = values[permutation[i] * inner_size + j]; + } + } + + return ArrayAttr::get(shuffled, attr.getContext()); +} + +// Shuffle ranked tensor dimensions according to the permutation. 
+Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { + if (auto ranked_type = type.dyn_cast()) { + ArrayRef shape = ranked_type.getShape(); + assert(permutation.size() == shape.size()); + + SmallVector new_shape(permutation.size()); + for (size_t i = 0; i < permutation.size(); ++i) + new_shape[i] = shape[permutation[i]]; + + return RankedTensorType::get(new_shape, ranked_type.getElementType()); + } + + return type; +} + +static bool AreCancellablePermutations(DenseIntElementsAttr perm0, + DenseIntElementsAttr perm1) { + if (perm0.getNumElements() == 0 || perm1.getNumElements() == 0) return false; + if (perm0.getNumElements() != perm1.getNumElements()) return false; + + SmallVector perm0_values; + for (const auto &value : perm0.getIntValues()) + perm0_values.push_back(value.getSExtValue()); + + SmallVector perm1_values; + for (const auto &value : perm1.getIntValues()) + perm1_values.push_back(value.getSExtValue()); + + for (int i = 0; i < perm0_values.size(); ++i) { + if (perm0_values[perm1_values[i]] != i) return false; + } + + return true; +} + +// Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for +// layout sensitive operations that do not have any additional layout dependent +// attributes besides `data_format` string. +template +LogicalResult UpdateDataFormat(StringRef data_format, Op *op) { + auto perm = GetDataFormatPermutation(op->data_format(), data_format); + if (perm.empty()) return failure(); + + // Update data format attribute. + op->setAttr("data_format", StringAttr::get(data_format, op->getContext())); + + // Update types for all layout sensitive results. + auto layout_sensitive = cast(op->getOperation()); + for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType(ShuffleRankedTensorType(result.getType(), perm)); + } + + return success(); +} + +// Default implementation for folding operand transpose into the operation. +// See `FoldOperandsTransposeInterface::FoldOperandsPermutation`. +template +LogicalResult FoldOperandsPermutation( + ArrayRef permutation, Op *op, + ArrayRef> shuffle_attrs = {}) { + MLIRContext *context = op->template getParentOfType().getContext(); + + // We only support NHWC <-> NCHW permutations. + static constexpr std::array kNchwToNhwc = {0, 2, 3, 1}; + static constexpr std::array kNhwcToNchw = {0, 3, 1, 2}; + + // Operation data format after folding `permutation`. + StringRef target_data_format = [&]() -> StringRef { + if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) { + return "NCHW"; // cancel NCHW->NHWC operand permutation + } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) { + return "NHWC"; // cancel NHWC->NCHW operand permutation + } else { + return ""; + } + }(); + if (target_data_format.empty()) return failure(); + + // To fold operand `permutation` into the `op` we need shuffle all layout + // dependent attributes and types with a reverse permutation, and change + // operation data format to `target_data_format`. + // + // Example: + // %1 = SomeOp(...) {data_format = NHWC} + // %2 = Transpose(%1) {permutation = NHWC->NCHW} + // %3 = Op(%2) {data_format = NCHW} + // + // To bypass %2 we have to change data format to shuffle data format from NCHW + // to NHWC, which is the reverse of operand permutation (function argument). 
+ auto reverse_permutation = + GetDataFormatPermutation(op->data_format(), target_data_format); + if (reverse_permutation.empty()) return failure(); + + op->setAttr("data_format", StringAttr::get(target_data_format, context)); + + for (auto pair : shuffle_attrs) { + StringRef attr_name = pair.first; + ArrayAttr attr_value = pair.second; + op->setAttr(attr_name, ShuffleArrayAttr(attr_value, reverse_permutation)); + } + + auto fold = cast(op->getOperation()); + for (unsigned idx : fold.GetLayoutDependentResults()) { + OpResult result = op->getOperation()->getResult(idx); + result.setType( + ShuffleRankedTensorType(result.getType(), reverse_permutation)); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Rewrite Pattern for removing trivial Arithmetic op. +//===----------------------------------------------------------------------===// + +namespace { +// Fold Arithmetic Op if one of the operands is a constant known to be an +// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if +// known identity value is either lhs or rhs. +template < + typename OpT, + typename std::enable_if::value>::type * = nullptr> +OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, + ArrayRef operands) { + auto lhs_type = arithmetic_op.x().getType().template cast(); + auto rhs_type = arithmetic_op.y().getType().template cast(); + auto result_type = + arithmetic_op.getResult().getType().template cast(); + + // We can fold arithmetic operation only of we can prove that we will not + // accidentally hide a broadcasting error. + auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, + ShapedType result_ty) -> bool { + // Scalar identity is broadcastable to any operand shape, we only need to + // check that operand has the same shape as a result. + bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; + if (scalar_identity) return operand_ty == result_ty; + + // If identity is not a scalar, we must verify that all shapes are equal + // and statically known. + // + // TODO(ezhulenev): Fold if identity shape is statically know to be + // broadcastable to the operand shape. + return operand_ty == result_ty && identity_ty == result_ty && + result_ty.hasStaticShape(); + }; + + // Check that we have a constant operand on one side (candidate for identity). + const bool is_commutative = + (std::is_same::value || std::is_same::value); + auto lhs_attr = operands[0].dyn_cast_or_null(); + auto rhs_attr = operands[1].dyn_cast_or_null(); + if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; + + // Mul and Div ops have identity value one while AddV2 and SubOp have identity + // value zero. + const int identity = + (std::is_same::value || std::is_same::value || + std::is_same::value) + ? 1 + : 0; + + Type element_ty = lhs_type.getElementType(); + Attribute identity_attr; + if (auto ty = element_ty.template dyn_cast()) { + identity_attr = FloatAttr::get(ty, static_cast(identity)); + } else if (auto ty = element_ty.template dyn_cast()) { + identity_attr = IntegerAttr::get(ty, static_cast(identity)); + } else { + return {}; + } + + // Fold: Op(Operand, Identity) -> Operand. + if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { + if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) + return arithmetic_op.x(); + } + + // Fold: Op(Identity, Operand) -> Operand for commutative operations. 
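A standalone sketch of just the identity-selection logic from the folder above, leaving out the broadcast-safety and splat-constant checks; `BinOp`, `IdentityValue`, and `FoldIdentity` are illustrative names.

#include <optional>

enum class BinOp { kAddV2, kSub, kMul, kDiv, kRealDiv };

// Identity element expected on the constant side: 0 for Add/Sub, 1 for
// Mul/Div/RealDiv.
constexpr double IdentityValue(BinOp op) {
  return (op == BinOp::kAddV2 || op == BinOp::kSub) ? 0.0 : 1.0;
}

// Returns which operand survives the fold: 0 = lhs, 1 = rhs, nullopt = no
// fold. Only Add and Mul are commutative, so only they may fold when the
// identity constant sits on the left.
std::optional<int> FoldIdentity(BinOp op, std::optional<double> lhs_const,
                                std::optional<double> rhs_const) {
  const double id = IdentityValue(op);
  const bool commutative = op == BinOp::kAddV2 || op == BinOp::kMul;
  if (rhs_const && *rhs_const == id) return 0;                  // x op id -> x
  if (commutative && lhs_const && *lhs_const == id) return 1;   // id op x -> x
  return std::nullopt;
}

// FoldIdentity(BinOp::kMul, std::nullopt, 1.0) -> 0   (x * 1 == x)
// FoldIdentity(BinOp::kSub, 0.0, std::nullopt) -> nullopt   (0 - x != x)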
+ if (lhs_attr && is_commutative && + is_valid_broadcasting(rhs_type, lhs_type, result_type)) { + if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) + return arithmetic_op.y(); + } + + return {}; +} +} // namespace + +// Verifies an reduction op's `input` and reduction `dims`. +static LogicalResult VerifyReductionInputAndDims(Value input, Value dims, + Location loc) { + auto dims_type = dims.getType().dyn_cast(); + if (!dims_type) return success(); + if (dims_type.getRank() > 1) + return emitError(loc, "dimensions can only be 0D or 1D tensor"); + + auto input_type = input.getType().dyn_cast(); + if (!input_type) return success(); + int64_t rank = input_type.getRank(); + + DenseIntElementsAttr dims_attr; + if (!matchPattern(dims, m_Constant(&dims_attr))) return success(); + for (const auto &dim_pair : llvm::enumerate(dims_attr)) { + int64_t cur_dim = dim_pair.value().getSExtValue(); + if (cur_dim < -rank || cur_dim >= rank) + return emitError(loc) + << dim_pair.index() << "-th dimension should be in the range of [-" + << rank << ", " << rank << ")"; + } + + return success(); +} + +LogicalResult VerifyRegionResults(Operation *op, Region ®ion, + StringRef region_name) { + auto op_name = op->getName().getStringRef(); + // verify that op outputs match yield inputs + YieldOp yield = cast(region.front().getTerminator()); + unsigned expected_num_results = op->getNumResults(); + if (yield.getNumOperands() != expected_num_results) + return op->emitOpError() + << region_name + " should have same number (" << expected_num_results + << ") of results as " << op_name << " but has " + << yield.getNumOperands() << " results"; + + for (int idx : llvm::seq(0, expected_num_results)) { + auto op_result_type = op->getResult(idx).getType().cast(); + auto region_result_type = + yield.getOperand(idx).getType().cast(); + if (!AreCastCompatible({region_result_type, op_result_type})) + return op->emitError(llvm::formatv( + "{0} result type {1} is incompatible with {2} " + "result type {3} at index {4}", + region_name, region_result_type, op_name, op_result_type, idx)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Function control flow canonicalization. +//===----------------------------------------------------------------------===// + +// Eliminate attributes that are not needed, but can get attached to Ops +// during import. +template +struct DropAttributes : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Drop the "output_shapes" attribute. + LogicalResult matchAndRewrite(Op op, + PatternRewriter &rewriter) const override { + bool found = op.removeAttr("output_shapes") == + MutableDictionaryAttr::RemoveResult::Removed; + return success(found); + } +}; + diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc new file mode 100644 index 00000000000..ffedcb47f7e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -0,0 +1,2326 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace TF { + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// NegOp +//===----------------------------------------------------------------------===// + +void NegOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// NotEqualOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(NotEqualOp op) { + // If we allow inputs to have incompatible type, then nothing to do. + if (!op.incompatible_shape_error()) return success(); + + // Otherwise, check inputs are broadcastable. 
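The broadcast check referenced above, and the result-type deduction used by the NotEqual builder that follows, both reduce to NumPy-style shape broadcasting. A self-contained sketch with dynamic dimensions handled conservatively; `BroadcastShapes` is an illustrative name.

#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Broadcast two ranked shapes (-1 = dynamic). Returns nullopt when the
// shapes are provably incompatible.
std::optional<std::vector<int64_t>> BroadcastShapes(std::vector<int64_t> a,
                                                    std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  b.insert(b.begin(), a.size() - b.size(), 1);  // left-pad with 1s
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    int64_t da = a[i], db = b[i];
    if (da == 1) out[i] = db;
    else if (db == 1 || da == db) out[i] = da;
    else if (da == -1 || db == -1) out[i] = -1;  // conservative: unknown size
    else return std::nullopt;                    // e.g. 2 vs 3: incompatible
  }
  return out;
}

// For tf.Equal / tf.NotEqual only the shape comes from broadcasting
// ({2, 1} vs {3} -> {2, 3}); the element type of the result is always i1.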
+ return mlir::OpTrait::impl::verifyCompatibleOperandBroadcast( + op.getOperation()); +} + +void NotEqualOp::build(OpBuilder &builder, OperationState &result, Value x, + Value y, BoolAttr incompatible_shape_error) { + auto result_type = DeduceEqualCmpOpType(&builder, result.location, x, y, + incompatible_shape_error); + return build(builder, result, result_type, x, y, incompatible_shape_error); +} + +//===----------------------------------------------------------------------===// +// OneHotOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(OneHotOp op) { + int64_t axis = op.axis().getSExtValue(); + + auto indices_ty = op.indices().getType().dyn_cast(); + if (indices_ty && + !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { + return op.emitOpError() + << "expected axis (" << axis << ") to be -1 or between [0, " + << indices_ty.getShape().size() << "]"; + } + + if (axis < -1) { + return op.emitOpError() << "expected axis (" << axis + << ") to be -1 or between [0, rank(indices()))"; + } + + if (!IsOfRankOrUnranked(op.depth(), 0)) { + return op.emitOpError() << "requires depth to be a scalar"; + } + if (!IsOfRankOrUnranked(op.on_value(), 0)) { + return op.emitOpError() << "requires on_value to be a scalar"; + } + if (!IsOfRankOrUnranked(op.off_value(), 0)) { + return op.emitOpError() << "requires off_value to be a scalar"; + } + + DenseIntElementsAttr depth_attr; + if (matchPattern(op.depth(), m_Constant(&depth_attr))) { + if (depth_attr.getType().getRank() != 0) + return op.emitOpError() << "requires depth to be a scalar"; + int64_t depth = depth_attr.getValue({}).getSExtValue(); + if (depth < 0) { + return op.emitOpError() << "depth must be non-negative, got: " << depth; + } + } + + return success(); +} + +static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, + Value off_value, IntegerAttr axis) { + int64_t axis_val = axis.getInt(); + Type element_ty = on_value.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + if (axis_val < -1) return unranked_ty; + + auto indices_ty = indices.getType().dyn_cast(); + if (!indices_ty) return unranked_ty; + + auto shape = llvm::to_vector<2>(indices_ty.getShape()); + if (axis_val == -1) axis_val = shape.size(); + + int64_t depth_val = ShapedType::kDynamicSize; + DenseIntElementsAttr depth_attr; + if (matchPattern(depth, m_Constant(&depth_attr)) && + depth_attr.getNumElements() == 1) + depth_val = (*depth_attr.begin()).getSExtValue(); + shape.insert(shape.begin() + axis_val, depth_val); + return RankedTensorType::get(shape, element_ty); +} + +void OneHotOp::build(OpBuilder &builder, OperationState &result, Value indices, + Value depth, Value on_value, Value off_value, + IntegerAttr axis) { + build(builder, result, + InferOneHotOpType(indices, depth, on_value, off_value, axis), indices, + depth, on_value, off_value, axis); +} + +//===----------------------------------------------------------------------===// +// PackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(PackOp op) { + // TODO(hinsu): Convert variadic length attributes to derived attributes. 
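As a side note to the OneHot type inference above: the output shape is just the indices shape with the depth inserted at the (already validated) axis. A minimal sketch with `OneHotShape` as an illustrative name.

#include <cstdint>
#include <vector>

// Insert the depth dimension at `axis`; axis == -1 appends it, mirroring the
// normalization done by the inference above.
std::vector<int64_t> OneHotShape(std::vector<int64_t> indices_shape,
                                 int64_t depth, int64_t axis) {
  if (axis == -1) axis = static_cast<int64_t>(indices_shape.size());
  indices_shape.insert(indices_shape.begin() + axis, depth);
  return indices_shape;
}

// OneHotShape({4, 3}, /*depth=*/5, /*axis=*/1)  -> {4, 5, 3}
// OneHotShape({4, 3}, /*depth=*/5, /*axis=*/-1) -> {4, 3, 5}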
+ Operation::operand_range values = op.values(); + + if (failed(VerifyTypesCompatibility(values, + /*mask_one_dim=*/false, + op.getOperation()))) { + return failure(); + } + + int64_t inputs_rank = -1; + for (Value value : values) { + if (auto ty = value.getType().dyn_cast()) { + // Exit early as input types are verified to be compatible so all ranked + // tensors have the same rank. + inputs_rank = ty.getRank(); + break; + } + } + if (inputs_rank == -1) return success(); + + // The values can be packed along any of the dimensions between 0 and + // inputs rank, inclusive. Also, as the negative axis values wrap around so + // the axis value range is [-(R+1), R+1). + int64_t range_begin = -inputs_rank - 1; // Inclusive + int64_t range_end = inputs_rank + 1; // Exclusive + int64_t axis = op.axis().getSExtValue(); + if (axis < range_begin || axis >= range_end) { + return op.emitError() << "attribute 'axis' should be within range [" + << range_begin << ", " << range_end + << "); actual value: " << axis; + } + + return success(); +} + +OpFoldResult PackOp::fold(ArrayRef operands) { + // Fold pack operation if it computes the input tensor shape: + // + // %shape = tf.Shape(%arg) // [? x ...] + // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value + // %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...] + // + // Where `...` are some statically known dimensions. In this case %pack can be + // replaced with a %shape. This is a common pattern in models with a dynamic + // batch size. + + // Pack operation should pack at least two values. + if (values().size() < 2) return {}; + + // Dimensions packed along axis = 0 (pack scalars into vector). + if (axis().getSExtValue() != 0) return {}; + + // First packed value is defined by a strided slice operation. + auto slice_op = dyn_cast_or_null(values()[0].getDefiningOp()); + if (!slice_op) return {}; + + // Input to the slice op is defined by shape operation. + auto shape_op = dyn_cast_or_null(slice_op.input().getDefiningOp()); + if (!shape_op) return {}; + + // Input tensor, which shape is reconstructed by the pack operation. + Value tensor = shape_op.input(); + + // All masks are `0` except `shrink_axis_mask` which is equal to `1` (slicing + // scalar value from input vector). + if (slice_op.begin_mask().getSExtValue() != 0 || + slice_op.ellipsis_mask().getSExtValue() != 0 || + slice_op.end_mask().getSExtValue() != 0 || + slice_op.new_axis_mask().getSExtValue() != 0 || + slice_op.shrink_axis_mask().getSExtValue() != 1) + return {}; + + // Returns a value if the `value` is defined by a ConstOp with a single + // integer element in it and has an expected rank. + auto get_const_int = [](Value value, int expected_rank) -> Optional { + auto const_op = dyn_cast_or_null(value.getDefiningOp()); + if (!const_op) return None; + + auto value_attr = const_op.value().dyn_cast(); + if (!value_attr || value_attr.getNumElements() != 1) return None; + + auto value_ty = value_attr.getType(); + if (!value_ty.hasRank() || value_ty.getRank() != expected_rank) return None; + + auto splat = value_attr.getSplatValue(); + return splat.getValue().getSExtValue(); + }; + + // All other packed values are scalar constants. 
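The shape-side condition of the fold described above can be stated on its own: the packed constants must reproduce every static dimension of a tensor whose only dynamic dimension is the first. A standalone sketch, with `PackReconstructsShape` as an illustrative name.

#include <cstdint>
#include <vector>

// True when Pack(dim0, d1, ..., dn) rebuilds the shape of a tensor with
// shape {?, d1, ..., dn} (-1 marks the dynamic dimension).
bool PackReconstructsShape(const std::vector<int64_t> &arg_shape,
                           const std::vector<int64_t> &packed_dims) {
  if (arg_shape.empty() || arg_shape[0] != -1) return false;
  if (arg_shape.size() != packed_dims.size() + 1) return false;
  for (size_t i = 1; i < arg_shape.size(); ++i) {
    if (arg_shape[i] == -1 || arg_shape[i] != packed_dims[i - 1]) return false;
  }
  return true;
}

// PackReconstructsShape({-1, 224, 224, 3}, {224, 224, 3}) -> true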
+ SmallVector packed_dims; + packed_dims.reserve(values().size() - 1); + for (Value operand : llvm::drop_begin(values(), 1)) { + if (auto dim = get_const_int(operand, /*expected_rank=*/0)) { + packed_dims.push_back(*dim); + } else { + return {}; + } + } + + // Slice exactly the first shape dimension: + // begin = [0] end = [1], strides = [1] + auto begin = get_const_int(slice_op.begin(), /*expected_rank=*/1); + auto end = get_const_int(slice_op.end(), /*expected_rank=*/1); + auto strides = get_const_int(slice_op.strides(), /*expected_rank=*/1); + if (!begin.hasValue() || !end.hasValue() || !strides.hasValue() || + *begin != 0 || *end != 1 || *strides != 1) + return {}; + + // First tensor dimension is dynamic. + auto arg_ty = tensor.getType().dyn_cast(); + if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 || + !arg_ty.isDynamicDim(0)) + return {}; + + // Argument tensor rank is equal to the number of packed dimensions. + if (arg_ty.getRank() != values().size()) return {}; + + // All other dimensions are statically known and equal to packed dims. + auto arg_dims = llvm::drop_begin(arg_ty.getShape(), 1); + if (!std::equal(arg_dims.begin(), arg_dims.end(), packed_dims.begin())) + return {}; + + // Replace %pack with %shape. + return slice_op.input(); +} + +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + +LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { + // Paddings must be defined by a constant operation. + auto paddings_op = dyn_cast_or_null(paddings().getDefiningOp()); + if (!paddings_op) return failure(); + + auto paddings_value = paddings_op.value().dyn_cast(); + if (!paddings_value || + paddings_value.getNumElements() != permutation.size() * 2) + return failure(); + + SmallVector shuffled_paddings(paddings_value.getNumElements()); + for (auto index_pair : llvm::enumerate(paddings_value.getIntValues())) { + size_t outer_idx = index_pair.index() / 2; + size_t inner_idx = index_pair.index() % 2; + + shuffled_paddings[permutation[outer_idx] * 2 + inner_idx] = + index_pair.value().getSExtValue(); + } + + // Add constant operation with a new paddings. + OpBuilder builder(getOperation()); + auto type = mlir::RankedTensorType::get(paddings_value.getType().getShape(), + builder.getIntegerType(32)); + auto values = mlir::DenseIntElementsAttr::get(type, shuffled_paddings); + auto shuffled_paddings_op = builder.create(getLoc(), values); + + // Use new paddings. + setOperand(1, shuffled_paddings_op); + + // Change the result type. + getResult().setType(ShuffleRankedTensorType(getResult().getType(), + ReversePermutation(permutation))); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ParseExampleV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ParseExampleV2Op op) { + // NOTE(mrry): This validates properties of an op that would previously be + // validated by the TensorFlow OpDef type checker. In addition to these + // checks, the shape inference function for ParseExampleV2 validates the + // consistency of the argument and result types. + + // Validate dense variadic input and output lengths. + // NOTE(mrry): The Tdense attr is derived from dense_defaults, so we + // do not need to validate dense_defaults. 
+ auto dense_types_count = + std::distance(op.Tdense().begin(), op.Tdense().end()); + auto dense_values_count = + std::distance(op.dense_values().begin(), op.dense_values().end()); + if (dense_values_count != dense_types_count) { + return op.emitError() << "output 'dense_values' should have same length " + << "as attribute 'Tdense'"; + } + + // Validate sparse variadic output lengths. + // NOTE(mrry): The sparse_types attr is derived from sparse_values, so we + // do not need to validate sparse_values. + auto sparse_types_count = + std::distance(op.sparse_types().begin(), op.sparse_types().end()); + if (op.num_sparse() != sparse_types_count) { + return op.emitError() << "attribute 'num_sparse' should be the same as " + << "the length of attribute 'sparse_types'"; + } + if (op.sparse_indices().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_indices' should have same length " + << "as attribute 'sparse_types'"; + } + if (op.sparse_shapes().size() != sparse_types_count) { + return op.emitError() << "output 'sparse_shapes' should have same length " + << "as attribute 'sparse_types'"; + } + + // Validate ragged variadic output lengths. + auto ragged_value_types_count = std::distance(op.ragged_value_types().begin(), + op.ragged_value_types().end()); + auto ragged_split_types_count = std::distance(op.ragged_split_types().begin(), + op.ragged_split_types().end()); + if (ragged_value_types_count != ragged_split_types_count) { + return op.emitError() << "attribute 'ragged_value_types' should have same " + << "length as attribute 'ragged_split_types'"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PartitionedCallOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult VerifyPartitionedCall(OpClass op) { + auto module = op.template getParentOfType(); + SymbolRefAttr func = op.getAttr("f").template cast(); + + auto function = + dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); + + if (!function) { + return op.emitError("'f' attribute refers to an undefined function: ") + << func; + } + + FunctionType function_ty = function.getType(); + int func_arg_count = function_ty.getNumInputs(); + int arg_count = op.args().size(); + + if (arg_count != func_arg_count) { + return op.emitError() << "argument count mismatch: 'args' has " << arg_count + << " arguments, but '" << func << "' expects " + << func_arg_count; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowOp::fold(ArrayRef operands) { + auto constant_y = operands[1].dyn_cast_or_null(); + if (constant_y && constant_y.isSplat()) { + APFloat y_value = constant_y.getSplatValue(); + auto output_type = getType().cast(); + if (y_value.isZero() && output_type.hasStaticShape()) { + return DenseElementsAttr::get( + output_type, + FloatAttr::get(output_type.getElementType(), /*value=*/1.0)); + } + if (y_value.isExactlyValue(1.0)) { + return x(); + } + } + return {}; +} + +//===----------------------------------------------------------------------===// +// QrOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input type, if ranked, must have at least 2 dimensions and at most +// INT32_MAX dimensions. 
+// +static LogicalResult Verify(QrOp op) { + auto ttype = op.input().getType().cast(); + if (!ttype.hasRank()) return success(); + if (!HasRankAtLeast(op.input(), 2)) + return op.emitOpError( + "requires ranked input tensor to be of rank 2 or more"); + if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + return op.emitOpError( + "requires ranked input tensor to be of rank INT32_MAX or less"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReadVariableOp +//===----------------------------------------------------------------------===// + +void ReadVariableOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ReciprocalOp +//===----------------------------------------------------------------------===// + +void ReciprocalOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// RandomUniformOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(RandomUniformOp op) { + if (!IsOfRankOrUnranked(op.shape(), 1)) + return op.emitOpError("shape must be 1D tensor"); + return success(); +} + +//===----------------------------------------------------------------------===// +// RangeOp +//===----------------------------------------------------------------------===// + +void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, + Value limit, Value delta) { + assert(start.getType() == limit.getType()); + assert(start.getType() == delta.getType()); + DenseIntElementsAttr start_val; + DenseIntElementsAttr limit_val; + DenseIntElementsAttr delta_val; + if (matchPattern(start, m_Constant(&start_val)) && + matchPattern(limit, m_Constant(&limit_val)) && + matchPattern(delta, m_Constant(&delta_val))) { + auto size = llvm::APIntOps::RoundingSDiv( + *limit_val.begin() - *start_val.begin(), *delta_val.begin(), + llvm::APInt::Rounding::DOWN); + return RangeOp::build( + builder, result, + RankedTensorType::get( + size.getSExtValue(), + start.getType().cast().getElementType()), + start, limit, delta); + } + return RangeOp::build( + builder, result, + RankedTensorType::get( + {-1}, start.getType().cast().getElementType()), + start, limit, delta); +} +//===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { + return RankOp::build(builder, result, + RankedTensorType::get({}, builder.getIntegerType(32)), + input); +} + +// This will create a constant value for RankOp of a ranked tensor. 
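As the comment above notes, tf.Rank folds for any ranked operand because only the rank, not the individual dimensions, needs to be static. A minimal sketch with `FoldRank` as an illustrative name and nullopt standing in for an unranked type.

#include <cstdint>
#include <optional>
#include <vector>

// Folds to the rank whenever the operand is ranked, even with dynamic dims.
std::optional<int32_t> FoldRank(
    const std::optional<std::vector<int64_t>> &shape) {  // nullopt = unranked
  if (!shape) return std::nullopt;  // unranked: no fold
  return static_cast<int32_t>(shape->size());
}

// FoldRank over a shape {4, -1, 3} yields 3; an unranked operand does not fold.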
+OpFoldResult RankOp::fold(ArrayRef operands) { + auto type = input().getType(); + auto ranked_type = type.dyn_cast(); + if (!ranked_type) return {}; + + auto output_type = getType().cast(); + int32_t rank = ranked_type.getRank(); + return DenseIntElementsAttr::get(output_type, rank); +} + +//===----------------------------------------------------------------------===// +// RealDivOp +//===----------------------------------------------------------------------===// + +void RealDivOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult RealDivOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +// TODO(b/128020684): Verify the output type. +static LogicalResult Verify(ReshapeOp op) { + auto shape_type = op.shape().getType().cast(); + if (!shape_type.hasRank()) return success(); + if (shape_type.getRank() != 1) + return op.emitOpError("shape must be 1D tensor"); + auto rank_by_shape = shape_type.getShape()[0]; + auto type_of_tensor = op.tensor().getType().cast(); + // No compile time verification for unknown sized shape. + if (rank_by_shape == -1 || !type_of_tensor.hasStaticShape()) return success(); + int64_t num_by_tensor = type_of_tensor.getNumElements(); + + auto out_ty = op.getType().dyn_cast(); + if (out_ty && out_ty.hasStaticShape()) { + int64_t num_output_elements = out_ty.getNumElements(); + if (num_by_tensor != num_output_elements) + return op.emitOpError() + << "number of output elements (" << num_output_elements + << ") does not match expected number of elements (" + << num_by_tensor << ")"; + } + + // Check values if constant shape. No compiling time verification for + // non-constant shape. + auto *shape_op = op.shape().getDefiningOp(); + if (!shape_op) return success(); + Attribute shape_cst; + if (!matchPattern(shape_op, m_Constant(&shape_cst))) return success(); + auto shape_cst_attr = shape_cst.dyn_cast(); + if (!shape_cst_attr) return op.emitOpError("shape must be a valid tensor"); + + if (auto opaque_attr = shape_cst_attr.dyn_cast()) { + opaque_attr.decode(shape_cst_attr); + } + + // We know the shape is a 1-D Tensor, then let us get the number of + // elements it implies. + unsigned num_by_shape = 1; + unsigned unknown_dim_count = 0; + for (int i = 0, e = rank_by_shape; i != e; ++i) { + auto num = shape_cst_attr.getValue(i).getInt(); + // The dimension size value can be -1, and that the real size needs to + // be computed so that the total size remains constant. At most one + // component of shape can be -1. + if (num == -1) { + if (++unknown_dim_count > 1) { + return op.emitOpError("more than one component of shape are -1"); + } + } else { + num_by_shape *= num; + } + } + // If there is one component of shape is -1, the dimension should be + // computed so that the total size remains constant. + if (unknown_dim_count == 1) { + if (num_by_tensor % num_by_shape != 0) + return op.emitOpError( + "one component of shape is -1 but couldn't infer the dimension"); + return success(); + } + // If the elements by the tensor and implies by the shape don't match, + // fail this static check. 
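The element-count bookkeeping used by this verifier and by the builder below can be summarized in one helper. A standalone sketch with `ResolveReshape` as an illustrative name; it folds the at-most-one -1 rule and the divisibility requirement into a single resolution step.

#include <cstdint>
#include <optional>
#include <vector>

// Resolves a reshape target against a known element count: at most one -1
// entry, and the known dimensions must divide the element count evenly so
// the -1 can be inferred.
std::optional<std::vector<int64_t>> ResolveReshape(int64_t num_elements,
                                                   std::vector<int64_t> shape) {
  int64_t known_product = 1;
  int unknown_index = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      if (unknown_index != -1) return std::nullopt;  // more than one -1
      unknown_index = static_cast<int>(i);
    } else {
      known_product *= shape[i];
    }
  }
  if (unknown_index == -1) {
    if (known_product != num_elements) return std::nullopt;  // count mismatch
    return shape;
  }
  if (known_product == 0 || num_elements % known_product != 0)
    return std::nullopt;  // the -1 cannot be inferred
  shape[unknown_index] = num_elements / known_product;
  return shape;
}

// ResolveReshape(24, {2, -1, 3}) yields {2, 4, 3};
// ResolveReshape(24, {5, -1}) fails because 5 does not divide 24.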
+ if (num_by_tensor != num_by_shape) { + return op.emitOpError( + "mismatch in tensor elements and shape implied elements"); + } + return success(); +} + +void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, + Value shape) { + auto ttype = tensor.getType().cast(); + auto etype = ttype.getElementType(); + + auto unranked = [&builder, etype, &result, shape, tensor]() { + return ReshapeOp::build(builder, result, UnrankedTensorType::get(etype), + tensor, shape); + }; + + // If tensor is unranked then we have no info about output of shape. + if (!ttype.hasRank()) return unranked(); + + DenseIntElementsAttr attr_shape; + if (matchPattern(shape, m_Constant(&attr_shape))) { + llvm::SmallVector const_shape; + const_shape.reserve(attr_shape.getNumElements()); + + // Detect if reshape output shape is folded. + bool flatten = false; + int unknown_index = -1; + // The product of constant shape argument excluding unknown dimension. + int64_t product_cshape = 1; + for (auto e : llvm::enumerate(attr_shape)) { + int64_t val = e.value().getSExtValue(); + if (IsUnknownDimOrRank(val)) { + if (flatten) { + mlir::emitError(result.location) + << "only one unknown dimension allowed"; + return; + } + flatten = true; + unknown_index = e.index(); + } else { + product_cshape *= val; + } + const_shape.push_back(val); + } + + // Compute the value of the unknown dimension. + if (flatten) { + // Compute number of elements in tensor shape. + auto tshape = ttype.getShape(); + int64_t product_tshape = std::accumulate(tshape.begin(), tshape.end(), 1, + std::multiplies()); + // Set the unknown dimension such that total number of elements remain + // constant. + // Note: The case where the ratio is not integral, and so the total size + // of reshape not constant, is checked in verify function. + const_shape[unknown_index] = product_tshape / product_cshape; + } + return ReshapeOp::build(builder, result, + RankedTensorType::get(const_shape, etype), tensor, + shape); + } + return unranked(); +} + +void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + Value tensor = this->tensor(); + + // Fold reshape if operand and result types are the same and all dimensions + // are statically known (no-op reshape). + // TODO(ezhulenev): Add the same folding for BroadcastToOp. + auto result_ty = getType().dyn_cast(); + if (result_ty && result_ty.hasStaticShape() && + result_ty == tensor.getType()) { + return tensor; + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +void SelectOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +// Verifies a few extra requirements on SelectOp: +// (1) `then` and `else` must have same shape +// (2) At least one of the following must be true: +// (a) `cond` has the same rank as `then` and `else` +// (b) `cond` is a scalar +// (c) `cond` is a vector AND `then` and `else` are non-scalar with their +// first dimension equal to `cond`. +static LogicalResult Verify(SelectOp op) { + auto then_tensor = op.t().getType().cast(); + auto else_tensor = op.e().getType().cast(); + // Check (1). 
+ if (!AreCastCompatible({then_tensor, else_tensor})) + return op.emitOpError() << "requires t and e have compatible shapes"; + + // Get data rank (if exists). + int data_rank; + // If data is unranked or data_rank is 0, this will remain -2. Otherwise + // refers to first dimension of then and/or else. + int data_first_dim = -2; + bool then_has_rank = then_tensor.hasRank(); + bool else_has_rank = else_tensor.hasRank(); + if (then_has_rank && else_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + if (else_tensor.getRank() > 0) + data_first_dim = std::max( + static_cast(else_tensor.getShape().front()), data_first_dim); + } else if (then_has_rank) { + data_rank = then_tensor.getRank(); + if (then_tensor.getRank() > 0) + data_first_dim = then_tensor.getShape().front(); + } else if (else_has_rank) { + data_rank = else_tensor.getRank(); + if (else_tensor.getRank() > 0) + data_first_dim = else_tensor.getShape().front(); + } else { + // Neither has a rank. + return success(); + } + + auto cond_tensor = op.condition().getType().dyn_cast(); + if (!cond_tensor) return success(); + auto cond_rank = cond_tensor.getRank(); + // Check (2a) and (2b). + if (cond_rank == 0 || cond_rank == data_rank) return success(); + // Check (2c). + if (cond_rank == 1) { + auto cond_shape = cond_tensor.getShape().front(); + if (data_rank == 0) { + return op.emitOpError() + << "requires that t and e are nonscalar when pred is a vector"; + } + // We know `data` tensor has a rank of at least 1. + if (data_first_dim != -1 && cond_shape != -1 && + data_first_dim != cond_shape) { + return op.emitOpError() << "requires that, when pred is a vector, the " + "shape matches the first dimension of t and e"; + } + return success(); + } + // None of (2a,b,c) were true; fail. + return op.emitOpError() << "requires that pred is a scalar OR has the same " + "rank as t and e OR is a vector"; +} + +//===----------------------------------------------------------------------===// +// SelectV2Op +//===----------------------------------------------------------------------===// + +static Type InferSelectV2OpType(Value condition, Value e, Value t) { + Type element_ty = e.getType().cast().getElementType(); + auto unranked_ty = UnrankedTensorType::get(element_ty); + + Type broadcasted_ty = + OpTrait::util::getBroadcastedType(e.getType(), t.getType()); + if (!broadcasted_ty) return unranked_ty; + + auto cond_ranked_ty = condition.getType().dyn_cast(); + auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); + if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; + + // Explicitly get broadcasted output type as element types of condition may + // not be same as the broadcated type's element type. + SmallVector result_shape; + if (!OpTrait::util::getBroadcastedShape(cond_ranked_ty.getShape(), + broadcasted_ranked_ty.getShape(), + result_shape)) + return unranked_ty; + return RankedTensorType::get(result_shape, element_ty); +} + +void SelectV2Op::build(OpBuilder &builder, OperationState &result, + Value condition, Value e, Value t) { + build(builder, result, InferSelectV2OpType(condition, e, t), condition, e, t); +} + +//===----------------------------------------------------------------------===// +// ShapeOp +//===----------------------------------------------------------------------===// + +namespace { +// Validates Shape/ShapeN/VariableShape operand and associated result types. 
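+// For example (hypothetical types), an operand of tensor<2x3x4xf32> admits a
+// result of tensor<3xi32> or tensor<3xi64>; tensor<2xi32> fails the rank/size
+// check, tensor<3x1xi32> fails the 1D requirement, and tensor<3xf32> fails
+// the int32/int64 element type requirement.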
+LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, + Type result_type, + int variadic_idx = -1) { + std::string variadic_idx_str = + variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str(); + + auto result_ranked_type = result_type.dyn_cast(); + if (!result_ranked_type) return success(); + if (result_ranked_type.getShape().size() != 1) + return op->emitOpError("requires 1D type for result") << variadic_idx_str; + + auto operand_ranked_type = operand_type.dyn_cast_or_null(); + if (operand_ranked_type) { + // The operand is a ranked tensor. + if (result_ranked_type.hasStaticShape() && + !operand_ranked_type.getShape().empty() && + result_ranked_type.getDimSize(0) != + operand_ranked_type.getShape().size()) + return op->emitOpError("requires dimension size of result") + << variadic_idx_str << " to match rank of operand" + << variadic_idx_str; + } else if (result_ranked_type.hasStaticShape()) { + // The operand is an unranked tensor, print a warning if the result + // is static. + // Note: We do not handle this situation as an error, this would be too + // restrictive due to incompleteness of shape inference at this point. + op->emitWarning("has static shape result") + << variadic_idx_str << " for unranked operand" << variadic_idx_str; + } + + Type element_type = result_ranked_type.getElementType(); + if (!element_type.isSignlessInteger(32) && + !element_type.isSignlessInteger(64)) + return op->emitOpError("requires int32 or int64 return type for result") + << variadic_idx_str; + + return success(); +} +} // anonymous namespace + +static LogicalResult Verify(ShapeOp op) { + return VerifyShapeOperandAndResult(op, op.input().getType(), op.getType()); +} + +// Converts shape of the given type to attribute if it is of ranked tensor type. +// Returned attribute has integer elements of the given width. +static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { + auto ranked_ty = input_ty.dyn_cast(); + if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; + + auto shape = ranked_ty.getShape(); + int rank = shape.size(); + + SmallVector dimensions; + dimensions.reserve(rank); + for (int i = 0; i < rank; ++i) + dimensions.push_back(APInt(out_width, shape[i])); + + auto result_type = RankedTensorType::get( + {rank}, IntegerType::get(out_width, input_ty.getContext())); + return DenseElementsAttr::get(result_type, dimensions); +} + +OpFoldResult ShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + return ConvertShapeToAttr(getOperand().getType(), width); +} + +void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, + BoolAttr use32Bit) { + auto rankedTensorType = input.getType().dyn_cast(); + int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; + auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) + : builder.getIntegerType(64); + return ShapeOp::build(builder, result, + RankedTensorType::get({rank}, out_type), input); +} + +//===----------------------------------------------------------------------===// +// ShapeNOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ShapeNOp op) { + const size_t num_tensors = op.N(); + + if (op.getNumOperands() != num_tensors) + return op.emitOpError() << "requires " << num_tensors << " operand(s), got " + << op.getNumOperands() << " operand(s)"; + + if (op.getNumResults() != num_tensors) + return op.emitOpError() << "requires " << num_tensors << " result(s), got " + << op.getNumResults() << " result(s)"; + + for (auto i : llvm::seq(0, num_tensors)) { + auto verification = VerifyShapeOperandAndResult( + op, op.getOperand(i).getType(), op.getResult(i).getType(), i); + if (failed(verification)) return verification; + } + + return success(); +} + +LogicalResult ShapeNOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + if (getNumOperands() == 0) return success(); + int width = + getType(0).cast().getElementType().getIntOrFloatBitWidth(); + + for (Type input_ty : getOperandTypes()) { + OpFoldResult result = ConvertShapeToAttr(input_ty, width); + if (!result) return failure(); + + results.push_back(result); + } + return success(); +} + +// TODO(hinsu): Add canonicalization pattern for ShapeN ops that don't have all +// static input shapes. Replacing output values corresponding to static input +// types may enable optimizations in users of the values. + +//===----------------------------------------------------------------------===// +// SizeOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input type, if is a ranked tensor, has at most INT32_MAX dimensions. +// +static LogicalResult Verify(SizeOp op) { + if (!HasRankAtMost(op.input(), std::numeric_limits::max())) + return op.emitOpError( + "requires ranked input tensor to be of rank INT32_MAX or less"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +// Verifies that: +// +// - operands begin and size are 1D with the same number of elements. +// - if the input is a ranked tensor, the rank of the input equals the number +// of elements in operands begin and size. 
+// - if begin are constants, that +// 0 <= begin[i] <= begin[i] + size[i] <= input_ty.getShape()[i] +// - if begins aren't constant but the input is a ranked tensor, that +// size[i] <= input_ty.getShape()[i] +// +static LogicalResult Verify(SliceOp op) { + RankedTensorType begin_ty = GetRankedTensorTypeForOperand(op.begin()); + if (begin_ty && begin_ty.getRank() != 1) { + return op.emitOpError() << "requires begin operand to be 1D tensor"; + } + + RankedTensorType size_ty = GetRankedTensorTypeForOperand(op.size()); + if (size_ty && size_ty.getRank() != 1) { + return op.emitOpError() << "requires size operand to be 1D tensor"; + } + + if (!begin_ty || !size_ty || !begin_ty.hasStaticShape() || + !size_ty.hasStaticShape()) + return success(); + + if (begin_ty.getNumElements() != size_ty.getNumElements()) { + return op.emitOpError() << "requires begin and size operands to have the" + " same number of elements"; + } + + auto input_ty = op.input().getType().dyn_cast(); + if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { + return op.emitOpError() << "requires number of elements in begin and size" + "are equal to input rank"; + } + + DenseIntElementsAttr begin_indices; + if (matchPattern(op.begin(), m_Constant(&begin_indices))) { + DenseIntElementsAttr slice_sizes; + bool constant_slice_sizes = + matchPattern(op.size(), m_Constant(&slice_sizes)); + int dim = 0; + for (const APInt &raw_begin_index : begin_indices.getValues()) { + int64_t begin_index = raw_begin_index.getSExtValue(); + int64_t input_size = input_ty ? input_ty.getShape()[dim] : -1; + int64_t slice_size = constant_slice_sizes + ? slice_sizes.getValue(dim).getSExtValue() + : 0; + if (slice_size == -1 && input_size != -1) { + slice_size = input_size - begin_index; + } + if (begin_index < 0 || + (input_size != -1 && begin_index + slice_size > input_size)) { + return op.emitOpError() + << "requires 0 <= begin[i] <= begin[i] + size[i] <= Di"; + } + ++dim; + } + } else if (input_ty) { + // If the inputs are ranked, we can do a few more sanity checks. + DenseIntElementsAttr slice_sizes; + if (matchPattern(op.size(), m_Constant(&slice_sizes))) { + auto input_shape = input_ty.getShape(); + for (int64_t i = 0; i < input_ty.getRank(); ++i) { + int64_t slice_size = slice_sizes.getValue(i).getInt(); + int64_t input_size = input_shape[i]; + if (slice_size != -1 && input_size != -1 && slice_size > input_size) { + return op.emitOpError() << "requires size[i] <= Di, even if begin[i] " + "is unknown at compile time"; + } + } + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// SoftmaxOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SoftmaxOp op) { + if (!HasRankAtLeast(op.logits(), 1)) { + return op.emitOpError("requires operand to have rank at least 1"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// SoftmaxCrossEntropyWithLogitsOp +//===----------------------------------------------------------------------===// + +// Verifies that, +// +// * Input types are broadcast compatible and the broadcasted type has rank two. 
+// +static LogicalResult Verify(SoftmaxCrossEntropyWithLogitsOp op) { + auto broadcasted_ty = OpTrait::util::getBroadcastedType( + op.features().getType(), op.labels().getType()) + .dyn_cast_or_null(); + if (!broadcasted_ty || + (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) + return op.emitOpError( + "requires features and labels to be broadcast compatible to rank two"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// SparseSoftmaxCrossEntropyWithLogitsOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SparseSoftmaxCrossEntropyWithLogitsOp op) { + if (!IsOfRankOrUnranked(op.features(), 2)) { + return op.emitOpError("requires features operand of rank two"); + } + if (!IsOfRankOrUnranked(op.labels(), 1)) { + return op.emitOpError("requires labels operand of rank one"); + } + auto features_ty = op.features().getType().dyn_cast(); + auto labels_ty = op.labels().getType().dyn_cast(); + if (features_ty && labels_ty) { + int64_t features_batches = features_ty.getDimSize(0); + int64_t labels_batches = labels_ty.getDimSize(0); + if (!ShapedType::isDynamic(features_batches) && + !ShapedType::isDynamic(labels_batches) && + features_batches != labels_batches) + return op.emitOpError( + "requires features and labels with matching first dimension"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// SplitOp +//===----------------------------------------------------------------------===// + +// Verifies the input and split dimension operands for tf.Split/tf.SplitV. +// Writes the split dimension's index (adjusted with input rank) via `dim_index` +// if it's a constant. +template +LogicalResult VerifySplitInputAndSplitDim(Op op, Optional *dim_index) { + *dim_index = llvm::None; + + Value split_dim = op.split_dim(); + if (auto split_dim_type = split_dim.getType().dyn_cast()) + if (split_dim_type.getRank() != 0) + return op.emitOpError( + "split dimension should be an integer scalar tensor"); + + // We can perform further verification if the input tensor to be split has + // known rank and the split dimension tensor is a constant. 
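+  // For example (hypothetical input), with a value of type tensor<4x6xf32>
+  // (rank 2), a constant split_dim must lie in [-2, 2); -1 is accepted and
+  // normalized to 1, while 2 is rejected by the range check below.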
+ + auto input_type = op.value().getType().template dyn_cast(); + if (!input_type) return success(); + + int64_t input_rank = input_type.getRank(); + if (input_rank == 0) + return op.emitOpError("cannot split scalar input tensor"); + + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(split_dim, m_Constant(&split_dim_attr))) return success(); + + int64_t index = (*split_dim_attr.begin()).getSExtValue(); + + if (index + input_rank < 0 || index >= input_rank) { + return op.emitOpError("split dimension must be in range [-") + << input_rank << ", " << input_rank << ")"; + } + + if (index < 0) index += input_rank; + *dim_index = index; + + return success(); +} + +static LogicalResult Verify(SplitOp op) { + Optional dim_index; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value().getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); + + if (input_dim_size % op.getNumResults() != 0) + return op.emitOpError("dimension #") + << *dim_index << " not divisible by the number of result tensors"; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SplitVOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SplitVOp op) { + auto split_sizes_type = + op.size_splits().getType().dyn_cast(); + if (!split_sizes_type) return success(); + + if (split_sizes_type.getRank() != 1 || + split_sizes_type.getDimSize(0) != op.getNumResults()) + return op.emitOpError("split sizes should be a 1D tensor of ") + << op.getNumResults() << " elements"; + + Optional dim_index = 0; + if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); + if (!dim_index) return success(); + + int64_t input_dim_size = + op.value().getType().cast().getDimSize(*dim_index); + if (input_dim_size == ShapedType::kDynamicSize) return success(); + + // If split sizes come from a constant, they must sum to the dimension size + // along split_dim, and we can have no more than one dynamic dimension. + DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.size_splits(), m_Constant(&split_sizes_attr))) + return success(); + + int64_t total_dim_size = 0; // Total dimension size assigned to splits + llvm::Optional dynamic_dim_index; + + SmallVector split_sizes; + split_sizes.reserve( + split_sizes_attr.getType().cast().getNumElements()); + + for (auto dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == ShapedType::kDynamicSize) { + // We cannot have more than one dynamic dimension. 
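+      // For example (hypothetical sizes), splitting a dimension of size 10:
+      // size_splits [2, 3, 5] is accepted, [2, 3, 4] is rejected below for
+      // summing to 9, [2, 3, -1] is accepted with 5 left for the dynamic
+      // slot, and [2, -1, -1] is rejected here for having two dynamic sizes.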
+ if (dynamic_dim_index) + return op.emitOpError( + "cannot have more than one dynamic dimension in split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + if (!dynamic_dim_index && total_dim_size != input_dim_size) + return op.emitOpError( + "split sizes must sum up to the dimension size along split " + "dimension, found ") + << total_dim_size << " vs " << input_dim_size; + + if (dynamic_dim_index && total_dim_size > input_dim_size) + return op.emitOpError( + "split sizes must sum up to be less than or equal to the " + "dimension size along split dimension, found ") + << total_dim_size << " vs " << input_dim_size; + + return success(); +} + +//===----------------------------------------------------------------------===// +// SquareOp +//===----------------------------------------------------------------------===// + +void SquareOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// SubOp +//===----------------------------------------------------------------------===// + +void SubOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +OpFoldResult SubOp::fold(ArrayRef operands) { + return IdentityArithmeticOpFolder(*this, operands); +} + +//===----------------------------------------------------------------------===// +// SumOp +//===----------------------------------------------------------------------===// + +void SumOp::build(OpBuilder &builder, OperationState &result, Value input, + Value reduction_indices, BoolAttr keep_dims) { + Type out_ty = + InferReductionOpType(input, reduction_indices, keep_dims, &builder); + build(builder, result, out_ty, input, reduction_indices, keep_dims); +} + +//===----------------------------------------------------------------------===// +// StridedSliceOp +//===----------------------------------------------------------------------===// + +// TODO(b/154160827): Add a canonicalization pattern from tf.StridedSliceOp to +// tf.SliceOp if both of the following are true: +// - All strides have a known value equal to 1 +// - No masks are set (or masks can be applied by transforming the inputs to +// Slice) + +// Verifies that, +// +// - begin, end and strides operands are 1D and they have the same number of +// elements. Here, the number of elements should be less than 32 to support +// 32-bit mask attributes. +// - None of the strides values are zero. +// - Ellipsis mask can have at most one bit set. + +template +static LogicalResult VerifyStridedSliceBase(OpTy op) { + // Expected size for operands begin, end and strides vector operands. + int64_t expected_size = -1; + + for (Value val : {op.begin(), op.end(), op.strides()}) { + auto operand_ty = val.getType().dyn_cast(); + if (!operand_ty || !operand_ty.hasStaticShape()) { + // TensorFlow constant ops may have non-static shape because the shape is + // not propagated during constant folding. If the defining op for this + // operand is a constant op, use the constant op's attribute to get the + // actual shape. 
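+      // For example (hypothetical IR), a begin operand produced by a tf.Const
+      // whose result type is tensor<?xi32> but whose value attribute is
+      // dense<[0, 1]> : tensor<2xi32> is handled here: the attribute's type
+      // supplies the static 1D shape of length 2 used by the checks below.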
+ DenseIntElementsAttr attr; + if (!matchPattern(val, m_Constant(&attr))) continue; + operand_ty = attr.getType(); + } + + if (operand_ty.getRank() != 1) + return op.emitOpError() + << "requires begin, end and strides to be 1D tensors"; + + int64_t length = operand_ty.getDimSize(0); + if (length == -1) continue; + + if (expected_size == -1) { + // This op uses 32-bit masks. + if (length >= 32) + return op.emitOpError( + "requires begin, end and strides operands with less than 32 " + "elements"); + + expected_size = length; + } else if (length != expected_size) { + return op.emitOpError() << "requires begin, end and strides to have the " + "same number of elements"; + } + } + + // If strides are constants, verify that none of the element is zero. + DenseIntElementsAttr strides; + if (matchPattern(op.strides(), m_Constant(&strides))) { + if (llvm::is_contained(strides.getValues(), 0)) + return op.emitOpError("requires non-zero strides"); + } + + // Use bit compares to ensure ellipsis_mask is 0 or a power of 2, i.e. there + // exists only no more than one ellipsis. + uint32_t ellipsis_mask = op.ellipsis_mask().getZExtValue(); + if (ellipsis_mask != 0 && !llvm::isPowerOf2_32(ellipsis_mask)) + return op.emitOpError("cannot have multiple ellipses"); + + return success(); +} + +// Clamps the given `val`: returns `low` if `val` is less than `low`; returns +// `high` if `high` is less than `val`; otherwise returns `val`. +template +constexpr const T &Clamp(const T &val, const T &low, const T &high) { + assert(!(high < low)); + return (val < low) ? low : (high < val) ? high : val; +} + +// Checks if the `index` bit of `val` is set. +template +constexpr bool IsSet(const T &val, unsigned index) { + return (val & (1 << index)) != 0; +} + +// Sets the `index` bit of `val`. +template +constexpr void Set(T &val, unsigned index) { + val |= (1 << index); +} + +// Unset the `index` bit of `val`. +template +constexpr void Unset(T &val, unsigned index) { + val &= ~(1 << index); +} + +// Copy the `src_index` bit of `src` to `dst_index` bit of `dst`. +template +constexpr void CopyBit(const T &src, unsigned src_index, T &dst, + unsigned dst_index) { + if (IsSet(src, src_index)) + Set(dst, dst_index); + else + Unset(dst, dst_index); +} + +// The sparse spec of strided slice does not correspond to the number of +// dimensions. For example, sparse spec for foo[..., 3:10] for foo of shape (2, +// 4, 8) would have dims = 2. +struct SparseSliceSpec { + int64_t dims; + int32_t begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; + const ArrayRef &begin; + const ArrayRef &end; + const ArrayRef &strides; +}; + +// The dense spec of strided slice is the canonicalized version of sparse spec. +// The number of dimensions of dense spec correspond to the number of dimensions +// in operand tensor. +struct DenseSliceSpec { + int64_t dims; + int32_t begin_mask, end_mask, shrink_axis_mask; + SmallVectorImpl &begin; + SmallVectorImpl &end; + SmallVectorImpl &strides; +}; + +// Make a sparse spec into a dense index spec. +// The sparse spec does not correspond to the number of dimensions +// Make a dense spec that corresponds to the number of dimensions +// +// For example suppose foo[...,3:, 2] on foo.shape=(2,2,3,4) then +// we need to produce the missing begin_mask, end_mask for the first two +// dimensions i.e. foo[:, :, 3:, 2]. 
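+//
+// Continuing that example (a rough walk-through of the expansion below): the
+// sparse spec has dims = 3 with the ellipsis at sparse index 0, "3:" at
+// sparse index 1 (begin 3, end mask set), and "2" at sparse index 2 (shrink
+// axis mask set). The dense spec has dims = 4: the ellipsis expands into
+// dense dims 0 and 1 with begin/end masks set and stride 1, dense dim 2
+// receives begin 3 with the end mask, and dense dim 3 receives begin 2 with
+// the shrink axis bit.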
+static void BuildDenseSliceSpec(const SparseSliceSpec &sparse, + DenseSliceSpec *dense) { + // Build expanded dense begin, end, strides, begin_mask, end_mask, and + // shrink_axis_mask. + dense->begin.resize(dense->dims); + dense->end.resize(dense->dims); + dense->strides.resize(dense->dims); + dense->begin_mask = 0; + dense->end_mask = 0; + dense->shrink_axis_mask = 0; + + // Count number of new_axis after ellipsis. This helps in calculating the + // number of dimensions ellipsis represents in the sparse spec. + bool ellipsis_seen = false; + int num_new_axis_after_ellipsis = 0; + for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { + if (ellipsis_seen && IsSet(sparse.new_axis_mask, sparse_index)) + num_new_axis_after_ellipsis++; + if (IsSet(sparse.ellipsis_mask, sparse_index)) ellipsis_seen = true; + } + + int dense_index = 0; + for (int sparse_index = 0; sparse_index < sparse.dims; ++sparse_index) { + if (IsSet(sparse.new_axis_mask, sparse_index)) continue; + if (IsSet(sparse.ellipsis_mask, sparse_index)) { + auto next_index = std::min(dense->dims - (sparse.dims - sparse_index) + + 1 + num_new_axis_after_ellipsis, + dense->dims); + // Expand ellipsis into the appropriate dense indices. From current index + // until next_index, all dimensions would have begin and end masks set and + // stride 1, i.e., get all elements in those dimensions. + for (; dense_index < next_index; ++dense_index) { + dense->begin[dense_index] = dense->end[dense_index] = 0; + dense->strides[dense_index] = 1; + Set(dense->begin_mask, dense_index); + Set(dense->end_mask, dense_index); + } + continue; + } + assert(dense_index < dense->dims); + // Copy over the sparse indices to dense indices if ellipsis_mask and + // new_axis_mask are not set. + dense->begin[dense_index] = sparse.begin[sparse_index]; + dense->end[dense_index] = sparse.end[sparse_index]; + dense->strides[dense_index] = sparse.strides[sparse_index]; + CopyBit(sparse.begin_mask, sparse_index, dense->begin_mask, dense_index); + CopyBit(sparse.end_mask, sparse_index, dense->end_mask, dense_index); + CopyBit(sparse.shrink_axis_mask, sparse_index, dense->shrink_axis_mask, + dense_index); + dense_index++; + } +} + +// For the given `input_shape`, calculates the sliced shape using the given +// `begin`, `end`, and `stride` ranges and `begin_mask`, `end_mask`, and +// `shrink_axis_mask` masks. Updates the result back to `input_shape`. If +// `shrink_axis_mask` is not zero, this function will not drop the corresponding +// dimensions in `input_shape`; it will turn them into 1s. At the same time, +// canonicalizes `begin`, `end`, and `strides. The calculation follows +// tf.StridedSlice op semantics. +static void CalculateSlicedShapeFromDenseIndices( + MutableArrayRef input_shape, int32_t begin_mask, int32_t end_mask, + int32_t shrink_axis_mask, MutableArrayRef begin, + MutableArrayRef end, MutableArrayRef stride) { + assert(input_shape.size() <= 32); // Only 32-bit masks are supported. + + // Make sure ranges' ranks are consistent with the input. 
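+  // For example (hypothetical values), for a dimension of size 10 with
+  // stride 1, begin -3 and the end mask set: begin is offset to 7 and clamped
+  // to [0, 10], end becomes 10, so the sliced size is 3; if the shrink axis
+  // bit is also set, the dimension is turned into 1 and end is forced to
+  // begin + 1.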
+ assert(input_shape.size() == begin.size()); + assert(input_shape.size() == end.size()); + assert(input_shape.size() == stride.size()); + + for (int i = 0, e = input_shape.size(); i < e; ++i) { + if (ShapedType::isDynamic(input_shape[i])) continue; + + int64_t dim_i = input_shape[i]; + int64_t begin_i = begin[i]; + int64_t end_i = end[i]; + int64_t stride_i = stride[i]; + + // [0]: mask for begin, [1]: mask for end + int64_t masks[] = {begin_mask & (1 << i), end_mask & (1 << i)}; + // [0]: bound for begin, [1]: bound for end + int64_t bounds[] = {stride_i > 0 ? 0 : -1, + stride_i > 0 ? dim_i : dim_i - 1}; + + // Canonicalizes the given range `point` (begin/end) according to the + // current dimension. `c` means case: 0 for begin, 1 for end. + auto canonicalize = [&](int64_t point, int c) { + if (masks[c]) return stride_i > 0 ? bounds[c] : bounds[(c + 1) & 1]; + + // Add dim as offset to negative range point. + point = point < 0 ? dim_i + point : point; + return Clamp(point, bounds[0], bounds[1]); + }; + + begin_i = canonicalize(begin_i, 0); + end_i = canonicalize(end_i, 1); + + int64_t interval_len = end_i - begin_i; + int64_t size_i = 0; + // If internal length is zero or has different sign from stride, it's a + // degenerated case: we are slicing nothing. Otherwise, calculate the sliced + // size. + if (interval_len != 0 && (interval_len < 0) == (stride_i < 0)) + size_i = (interval_len / stride_i) + (interval_len % stride_i != 0); + + begin[i] = begin_i; + if (IsSet(shrink_axis_mask, i)) { + // Shrink this dimension. It means we only take the element at begin_i. + input_shape[i] = 1; + end[i] = begin_i + 1; + stride[i] = 1; + } else { + input_shape[i] = size_i; + end[i] = end_i; + stride[i] = stride_i; + } + } +} + +// For the given `input_shape`, calculates the sliced shape using the given +// `sparse_begin`, `sparse_end`, and `sparse_strides` ranges and `begin_mask`, +// `end_mask`, `ellipsis_mask` , `new_axis_mask` and `shrink_axis_mask` masks. +// Updates the result back to `input_shape`. +static void CalculateSlicedShapeFromSparseIndices( + MutableArrayRef input_shape, ArrayRef sparse_begin, + ArrayRef sparse_end, ArrayRef sparse_strides, + int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, + int32_t new_axis_mask, int32_t shrink_axis_mask, + SmallVectorImpl *begin, SmallVectorImpl *end, + SmallVectorImpl *stride) { + int64_t num_sparse_indices = sparse_begin.size(); + SparseSliceSpec sparse = {num_sparse_indices, begin_mask, end_mask, + ellipsis_mask, new_axis_mask, shrink_axis_mask, + sparse_begin, sparse_end, sparse_strides}; + + // If no ellipsis_mask exists then an implicit ellipsis_mask at the end is + // inserted. This handles cases where foo[2:4] (foo.shape() = [4, 8]) yields + // a tensor of shape [2, 8], i.e., foo[2:4] is same as foo[2:4, ...]. 
+ if (sparse.ellipsis_mask == 0) { + Set(sparse.ellipsis_mask, sparse.dims); + sparse.dims++; + } + + int64_t dims = input_shape.size(); + DenseSliceSpec dense = {dims, + /*begin_mask = */ 0, + /*end_mask = */ 0, + /*shrink_axis_mask = */ 0, + *begin, + *end, + *stride}; + + BuildDenseSliceSpec(sparse, &dense); + CalculateSlicedShapeFromDenseIndices(input_shape, dense.begin_mask, + dense.end_mask, dense.shrink_axis_mask, + *begin, *end, *stride); +} + +bool StridedSliceOp::GetSlicedBoundRanges( + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { + // TODO(hinsu): Support lowering for ops with dynamic begin and end values + // when it is possible to derive indices based on mask attributes. + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; + if (!matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + return false; + + auto input_ty = this->input().getType().dyn_cast(); + if (!input_ty || !input_ty.hasStaticShape()) return false; + auto input_shape = llvm::to_vector<4>(input_ty.getShape()); + + SmallVector sparse_begin, sparse_end, sparse_strides; + + for (const APInt &index : sparse_begin_attr) + sparse_begin.push_back(index.getSExtValue()); + for (const APInt &index : sparse_end_attr) + sparse_end.push_back(index.getSExtValue()); + for (const APInt &stride : sparse_strides_attr) + sparse_strides.push_back(stride.getSExtValue()); + + CalculateSlicedShapeFromSparseIndices( + input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + return true; +} + +//===----------------------------------------------------------------------===// +// StridedSliceGradOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(StridedSliceGradOp op) { + auto shape_type = op.shape().getType().dyn_cast(); + if (shape_type && shape_type.getRank() != 1) + return op.emitOpError("'shape' operand must be 1D tensor, but got ") + << shape_type.getRank() << "D tensor"; + + if (failed(VerifyStridedSliceBase(op))) return failure(); + + // TODO(antiagainst): verify the gradient op.dy()'s shape is consistent with + // the sliced type from StridedSlice. 
+ + return success(); +} + +bool StridedSliceGradOp::GetSlicedShapeAndBoundRanges( + SmallVectorImpl *input_shape, + SmallVectorImpl *slice_begin, SmallVectorImpl *slice_end, + SmallVectorImpl *slice_stride) { + DenseIntElementsAttr shape_attr; + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr, sparse_strides_attr; + if (!matchPattern(shape(), m_Constant(&shape_attr)) || + !matchPattern(begin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(end(), m_Constant(&sparse_end_attr)) || + !matchPattern(strides(), m_Constant(&sparse_strides_attr))) + return false; + + int rank = std::distance(shape_attr.begin(), shape_attr.end()); + + input_shape->clear(); + input_shape->reserve(rank); + for (const APInt &dim : shape_attr) + input_shape->push_back(dim.getSExtValue()); + + SmallVector sparse_begin, sparse_end, sparse_strides; + + for (const APInt &index : sparse_begin_attr) + sparse_begin.push_back(index.getSExtValue()); + for (const APInt &index : sparse_end_attr) + sparse_end.push_back(index.getSExtValue()); + for (const APInt &stride : sparse_strides_attr) + sparse_strides.push_back(stride.getSExtValue()); + + CalculateSlicedShapeFromSparseIndices( + *input_shape, sparse_begin, sparse_end, sparse_strides, + begin_mask().getZExtValue(), end_mask().getZExtValue(), + ellipsis_mask().getZExtValue(), new_axis_mask().getZExtValue(), + shrink_axis_mask().getZExtValue(), slice_begin, slice_end, slice_stride); + return true; +} + +//===----------------------------------------------------------------------===// +// TensorListReserveOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorListReserveOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + + if (!IsOfRankOrUnranked(op.num_elements(), 0)) { + return op.emitOpError("requires num_elements operand to be 0D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TensorListElementShapeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TensorListElementShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto variant_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (variant_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); +} + +//===----------------------------------------------------------------------===// +// TensorListStackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorListStackOp op) { + if (!IsOfRankOrUnranked(op.element_shape(), 0) && + !IsOfRankOrUnranked(op.element_shape(), 1)) { + return op.emitOpError("requires element_shape operand to be 0D/1D tensor"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TensorScatterUpdateOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TensorScatterUpdateOp op) { + if (!HasRankAtLeast(op.tensor(), 1)) + return op.emitOpError( + "requires tensor operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.indices(), 1)) + return op.emitOpError( + "requires indices operand to have at least 1 dimension"); + if (!HasRankAtLeast(op.updates(), 
1)) + return op.emitOpError( + "requires updates operand to have at least 1 dimension"); + + auto tensor_ty = op.tensor().getType().dyn_cast(); + auto indices_ty = op.indices().getType().dyn_cast(); + if (!tensor_ty || !indices_ty) return success(); + + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return success(); + + if (num_index_dims > tensor_ty.getRank()) + return op.emitOpError( + "requires tensor operand with rank greater than or equal to the " + "indices operand's last dimensions"); + return success(); +} + +//===----------------------------------------------------------------------===// +// TopKV2Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TopKV2Op op) { + if (!HasRankAtLeast(op.input(), 1)) + return op.emitOpError( + "requires input operand to have at least 1 dimension"); + + if (!IsOfRankOrUnranked(op.k(), 0)) + return op.emitOpError("requires k operand to be 0D tensor"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// ToBoolOp +//===----------------------------------------------------------------------===// + +namespace { +// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity +// function and can be removed. +class ToBoolOfZeroDBoolTensor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ToBoolOp op, + PatternRewriter &rewriter) const override { + if (auto type = op.getOperand().getType().dyn_cast()) { + if (type.getRank() == 0 && type.getElementType().isInteger(1)) { + rewriter.replaceOp(op, op.getOperand()); + return success(); + } + } + return failure(); + } +}; +} // namespace + +void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TransposeOp op) { + auto perm_type = op.perm().getType().dyn_cast(); + auto x_type = op.x().getType().dyn_cast(); + auto y_type = op.y().getType().dyn_cast(); + + if (perm_type && perm_type.getRank() != 1) { + return op.emitOpError() + << "expected perm to be a 1-D Tensor, got perm of rank " + << perm_type.getRank(); + } + + if (x_type && y_type && x_type.getRank() != y_type.getRank()) { + return op.emitOpError() << "x should be of the same rank with y, got " + << "x of rank " << x_type.getRank() + << ", and y of rank " << y_type.getRank(); + } + + if (!x_type || !y_type || !perm_type || !perm_type.hasStaticShape()) { + return success(); + } + + if (x_type.getRank() != perm_type.getNumElements()) { + return op.emitOpError() << "expected perm to be a 1-D Tensor of size " + << "equal to the rank of x, got perm of size " + << perm_type.getNumElements() << ", and x of rank " + << x_type.getRank(); + } + + DenseIntElementsAttr attr_perm; + if (matchPattern(op.perm(), m_Constant(&attr_perm))) { + // y.shape[i] should be equal to x.shape[perm[i]] + // for i = [0, 1, ..., rank(x) - 1] + for (auto e : llvm::enumerate(attr_perm)) { + const int64_t y_idx = e.index(); + const int64_t y_dim = y_type.getDimSize(y_idx); + const int64_t x_idx = e.value().getSExtValue(); + const int64_t x_dim = x_type.getDimSize(x_idx); + if (y_dim != ShapedType::kDynamicSize && + x_dim != ShapedType::kDynamicSize && y_dim != 
x_dim) { + return op.emitOpError() + << "requires y.shape[" << y_idx << "] (" << y_dim << ") " + << "to be equal to x.shape[perm[" << x_idx << "]] " + << "(" << x_dim << ")"; + } + } + } + + return success(); +} + +// TODO(jpienaar): perm could be optional too. +void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, + Value perm) { + auto x_type = x.getType().cast(); + // If value is unranked, then so is results. + if (!x_type.hasRank()) + return TransposeOp::build(builder, result, + UnrankedTensorType::get(x_type.getElementType()), + x, perm); + + // TODO(jpienaar): Handle unknown perm case. + + // TODO(jpienaar): Extract utility function. + auto etype = x_type.cast().getElementType(); + DenseIntElementsAttr attr_shape; + if (matchPattern(perm, m_Constant(&attr_shape))) { + llvm::SmallVector const_shape; + if (attr_shape.isSplat()) { + const_shape.assign( + attr_shape.getNumElements(), + x_type.getDimSize((*attr_shape.begin()).getSExtValue())); + } else { + const_shape.reserve(attr_shape.getNumElements()); + for (const auto &dim : attr_shape) + const_shape.push_back(x_type.getDimSize(dim.getSExtValue())); + } + return TransposeOp::build( + builder, result, RankedTensorType::get(const_shape, etype), x, perm); + } + return TransposeOp::build(builder, result, UnrankedTensorType::get(etype), x, + perm); +} + +namespace { + +OpFoldResult FoldIdentityTranspose(TransposeOp op) { + auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); + if (!const_perm) return {}; + + auto const_value = const_perm.value(); + const auto elements = const_value.getValues(); + + for (auto it : llvm::enumerate(elements)) { + if (it.index() != it.value()) return {}; + } + + // TODO(jpienaar): Remove if/when we handle this more generally. + if (op.getType() != op.x().getType()) { + // If the types don't match then only fold if all the operands are in the TF + // dialect. + for (auto user : op.getOperation()->getUsers()) + if (user->getDialect() != op.getDialect()) return {}; + } + + return op.x(); +} + +OpFoldResult FoldCancellableTranspose(TransposeOp op) { + // Operand is a TransposeOp. + auto transpose = dyn_cast_or_null(op.x().getDefiningOp()); + if (!transpose) return {}; + + // Permutations defined by constant operations. 
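+  // For example, tf.Transpose(tf.Transpose(x, [1, 0, 2]), [1, 0, 2]) folds to
+  // x, since applying the permutation [1, 0, 2] twice is the identity; more
+  // generally the two constant permutations cancel when they are inverses of
+  // each other.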
+ auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); + auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); + if (!perm0 || !perm1) return {}; + + // With permutation indices that cancel each other + auto perm0_value = perm0.value().cast(); + auto perm1_value = perm1.value().cast(); + if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; + + return transpose.x(); +} + +} // namespace + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + if (auto folded = FoldIdentityTranspose(*this)) return folded; + if (auto folded = FoldCancellableTranspose(*this)) return folded; + return {}; +} + +//===----------------------------------------------------------------------===// +// TruncateDivOp +//===----------------------------------------------------------------------===// + +void TruncateDivOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// UnpackOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(UnpackOp op) { + auto value_type = op.value().getType().dyn_cast(); + if (!value_type) return success(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.axis().getSExtValue(); + if (axis < -value_rank || axis >= value_rank) + return op.emitOpError("axis attribute must be in the range of [-") + << value_rank << ", " << value_rank << ')'; + + axis = GetDimForAxis(axis, value_rank); + int64_t dim_size = value_type.getDimSize(axis); + if (ShapedType::isDynamic(dim_size)) return success(); + + if (dim_size != op.getNumResults()) + return op.emitOpError("result count must be equal to ") << dim_size; + + return success(); +} + +//===----------------------------------------------------------------------===// +// Unsorted segment reduction ops +//===----------------------------------------------------------------------===// + +template +static LogicalResult VerifyUnsortedSegmentReduction(Op op) { + if (!HasRankAtMost(op.num_segments(), 0)) + return op.emitOpError("number of segments should be a 0-D tensor"); + + auto data_type = op.data().getType().template dyn_cast(); + auto segment_ids_type = + op.segment_ids().getType().template dyn_cast(); + if (data_type && segment_ids_type) { + if (data_type.getRank() < segment_ids_type.getRank()) + return op.emitOpError( + "requires segment ids rank to be less than or equal to data's rank"); + + int index = 0; + for (auto shape_pair : + llvm::zip_first(segment_ids_type.getShape(), data_type.getShape())) { + int64_t segment_id_dim = std::get<0>(shape_pair); + int64_t data_dim = std::get<1>(shape_pair); + if (!ShapedType::isDynamic(segment_id_dim) && + !ShapedType::isDynamic(data_dim) && segment_id_dim != data_dim) + return op.emitOpError( + "requires segment ids shape to be a prefix of data shape, " + "but dimension #") + << index << " differs: " << segment_id_dim << " vs. 
" + << data_dim; + ++index; + } + } + + DenseIntElementsAttr num_segments_attr; + if (matchPattern(op.num_segments(), m_Constant(&num_segments_attr))) { + int64_t num_segments = (*num_segments_attr.begin()).getSExtValue(); + if (num_segments < 0) + return op.emitOpError("num of segments cannot be negative"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// VarIsInitializedOp +//===----------------------------------------------------------------------===// + +namespace { + +/// Erase VarIsInitializedOp operations with no uses. This op has side effect on +/// resources (read-only), but can still be deleted if it has zero uses. +struct EraseDeadVarIsInitializedOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(VarIsInitializedOp op, + PatternRewriter &rewriter) const override { + if (!op.use_empty()) return failure(); + rewriter.eraseOp(op); + return success(); + } +}; +} // end anonymous namespace. + +void VarIsInitializedOp::getCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); +} + +//===----------------------------------------------------------------------===// +// VariableShapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(VariableShapeOp op) { + auto input_type = op.input().getType().cast(); + if (input_type.hasStaticShape() && input_type.getNumElements() != 1) + return op.emitOpError("requires input to have one resource"); + + auto resource_type = input_type.getElementType().cast(); + auto subtypes = resource_type.getSubtypes(); + switch (subtypes.size()) { + case 1: + return VerifyShapeOperandAndResult( + op, resource_type.getSubtypes().front(), op.getType()); + case 0: + return VerifyShapeOperandAndResult(op, Type(), op.getType()); + default: + return op.emitOpError( + "requires resource input type to have at most 1 subtype"); + } +} + +OpFoldResult VariableShapeOp::fold(ArrayRef operands) { + int width = + getType().cast().getElementType().getIntOrFloatBitWidth(); + auto resource_type = + getElementTypeOrSelf(getOperand().getType()).cast(); + if (resource_type.getSubtypes().empty()) return {}; + return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); +} + +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(WhileOp op) { + auto cond_fn = op.cond_func(); + auto body_fn = op.body_func(); + if (!cond_fn) { + return op.emitOpError("cond refers to an undefined function : ") + << op.cond(); + } + if (!body_fn) { + return op.emitOpError("body refers to an undefined function : ") + << op.body(); + } + + auto cond_fn_type = cond_fn.getType(); + auto body_fn_type = body_fn.getType(); + + // Verify that the cond function has exactly one result. + if (cond_fn_type.getNumResults() != 1) + return op.emitOpError("requires cond function to have exactly one result"); + + SmallVector operands(op.getOperandTypes()); + + // Collect all the type lists for the op so that different pairs of type lists + // can be compared for the compatibility. 
+ constexpr int kNumTypeLists = 5; + const std::array>, kNumTypeLists> + type_lists = {{ + {"operand", operands}, + {"body function result", body_fn_type.getResults()}, + {"result", op.getResultTypes()}, + {"cond function input", cond_fn_type.getInputs()}, + {"body function input", body_fn_type.getInputs()}, + }}; + + // A pair of type lists should be cast compatible with each other if one is + // converted to the another for a function call or assignment or there is a + // common source of inputs for both. Therefore, the While op requires the + // following pairs of type lists to be cast compatible for the tensor_cast + // operation: + // + // * Operands and cond inputs to call the cond function before the + // first iteration. + // * Operands and body inputs to call the body function for the first + // iteration if the cond functions returns True or equivalent result. + // * Operands and results to assign cond function arguments to op results if + // the cond function returns False or equivalent result. + // * All three pairs using cond inputs, body inputs and results as operand is + // a common source for all three. + // * Body result and cond inputs to call the cond function for the subsequent + // iterations. Similarly, Body result should be compatible with body inputs + // and op results. + // + // Note that the operands and body results need not be compatible as they are + // never converted from one to the another nor there is a common source + // tensors. Compatibility requirement is not transitive. + + for (int i = 0; i < kNumTypeLists; ++i) { + // Skip the first pair as the While op operands and body function results + // does not need to be compatible with each other. + for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) { + auto &a = type_lists[i]; + auto &b = type_lists[j]; + + int a_size = a.second.size(); + if (a_size != b.second.size()) + return op.emitOpError( + llvm::formatv("requires the number of {0}s to be equal to the " + "number of {1}s. Found {2} and {3}, respectively", + a.first, b.first, a_size, b.second.size())); + + for (int idx = 0; idx < a_size; ++idx) { + auto a_type = a.second[idx]; + auto b_type = b.second[idx]; + + if (!AreCastCompatible({a_type, b_type})) + return op.emitError(llvm::formatv( + "{0} type {1} is incompatible with {2} type {3} at index {4}", + a.first, a_type, b.first, b_type, idx)); + } + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileOp canonicalization. +//===----------------------------------------------------------------------===// +void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert>(context); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp +//===----------------------------------------------------------------------===// +static LogicalResult Verify(WhileRegionOp op) { + // Verify that the condition generates a single tensor result. + YieldOp yield = cast(op.cond().front().getTerminator()); + if (yield.getNumOperands() != 1) + return op.emitOpError() + << "condition should have a single tensor result"; + + auto cond_type = yield.getOperand(0).getType().dyn_cast(); + if (!cond_type || !cond_type.getShape().equals({}) || + !cond_type.getElementType().isInteger(/*width=*/1)) + return op.emitOpError() + << "condition should have a single tensor result"; + + // The body result types should match while op result types. 
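+  // For example (hypothetical types), a WhileRegion with operands
+  // (tensor<i32>, tensor<?xf32>) requires cond and body blocks that each take
+  // two arguments cast compatible with those types, with the cond region
+  // yielding a single scalar i1 tensor as checked above.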
+ if (failed(VerifyRegionResults(op, op.body(), "body"))) return failure(); + + // Both condition and body should have same number and type of operands as + // the WhileRegion inputs. + const int num_inputs = op.getNumOperands(); + auto block_inputs_match_op_inputs = [&](Region ®ion, + StringRef name) -> LogicalResult { + Block &block = region.front(); + if (block.getNumArguments() != num_inputs) + return op.emitOpError() + << name << " should have same number of inputs (" << num_inputs + << ") as " << WhileRegionOp::getOperationName() << " but has " + << block.getNumArguments() << " inputs"; + + for (auto types_idx : llvm::enumerate( + llvm::zip(op.getOperandTypes(), block.getArgumentTypes()))) { + auto op_input_type = std::get<0>(types_idx.value()); + auto block_input_type = std::get<1>(types_idx.value()); + if (!AreCastCompatible({block_input_type, op_input_type})) + return op.emitOpError(llvm::formatv( + "{0} input type {1} is incompatible with {2} " + "input type {3} at index {4}", + name, block_input_type, WhileRegionOp::getOperationName(), + op_input_type, types_idx.index())); + } + return success(); + }; + + if (failed(block_inputs_match_op_inputs(op.cond(), "condition")) || + failed(block_inputs_match_op_inputs(op.body(), "body"))) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp LoopLikeOpInterface +//===----------------------------------------------------------------------===// + +Region &WhileRegionOp::getLoopBody() { return body(); } + +bool WhileRegionOp::isDefinedOutsideOfLoop(Value value) { + // If the Op defining the value exists and the defining op is outside the + // scope of this WhileRegion, then we can infer that its defined outside. + // The defining Op is outside the scope of this WhileRegion if this + // WhileRegionOp is not an ancestor of the defining op in the parent chain. + Operation *def_op = value.getDefiningOp(); + return def_op && !getOperation()->isAncestor(def_op); +} + +LogicalResult WhileRegionOp::moveOutOfLoop( + llvm::ArrayRef ops) { + // Move the hoisted value to just before the while. + Operation *while_op = this->getOperation(); + for (auto op : ops) op->moveBefore(while_op); + return success(); +} + +//===----------------------------------------------------------------------===// +// WhileRegionOp canonicalization +//===----------------------------------------------------------------------===// +namespace { +// Eliminate values that pass through the WhileRegionOp body. +struct WhileRegionEliminatePassThrough + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileRegionOp while_op, + PatternRewriter &rewriter) const override { + // Replace values that simply passthrough the body with extern values. The + // block arguments of body and while match and so the corresponding cond + // argument can be easily found. + int old_num_operands = while_op.getNumOperands(); + int new_num_operands = old_num_operands; + auto &body_block = while_op.body().front(); + auto &cond_block = while_op.cond().front(); + auto &yield = *body_block.getTerminator(); + + // Bit mask indicating which operands will be removed. 
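+    // For example, if the body simply yields its block argument #1 and, after
+    // the rewrites below, neither region's argument #1 nor result #1 has any
+    // remaining use, operand #1 is dropped and the op is rebuilt with one
+    // fewer operand and result.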
+ SmallVector removed_operand(old_num_operands, false); + + for (int op_idx : llvm::seq(0, old_num_operands)) { + auto body_arg = body_block.getArgument(op_idx); + if (body_arg == yield.getOperand(op_idx)) { + // Replace the use of the passthrough value with the while operand + // in the body and condition regions, as well as the while output (if + // type match) + // TODO(jurahul): Use PatternRewriter API for IR modification. + auto value = while_op.getOperand(op_idx); + if (body_arg.getType() == value.getType()) + body_arg.replaceAllUsesWith(value); + + auto cond_arg = cond_block.getArgument(op_idx); + if (cond_arg.getType() == value.getType()) + cond_arg.replaceAllUsesWith(value); + + auto result = while_op.getResult(op_idx); + if (result.getType() == value.getType()) + result.replaceAllUsesWith(value); + } + + // Now check if the operand is unused in both regions as well as the + // result. If so, mark it for removal. + if (body_block.getArgument(op_idx).use_empty() && + cond_block.getArgument(op_idx).use_empty() && + while_op.getResult(op_idx).use_empty()) { + removed_operand[op_idx] = true; + new_num_operands--; + } + } + + if (new_num_operands == old_num_operands) return failure(); + + // Compress the operands, region arguments, and outputs. + SmallVector new_while_operands; + SmallVector new_result_types; + new_while_operands.reserve(new_num_operands); + new_result_types.reserve(new_num_operands); + + // Build new operands and result type. + int next_idx = 0; + for (int op_idx : llvm::seq(0, old_num_operands)) { + if (removed_operand[op_idx]) continue; + new_while_operands.push_back(while_op.getOperand(op_idx)); + new_result_types.push_back(while_op.getResult(op_idx).getType()); + next_idx++; + } + + // Create the new while operation. + auto new_while_op = + rewriter.create(while_op.getLoc(), new_result_types, + new_while_operands, while_op.getAttrs()); + + // Move region bodies to the new while. + rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(), + new_while_op.cond().end()); + rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(), + new_while_op.body().end()); + + auto &new_cond_block = new_while_op.cond().front(); + auto &new_body_block = new_while_op.body().front(); + auto &new_yield = *new_body_block.getTerminator(); + + // Build a vector of new results. Also patch up the region bodies and yield. 
+    SmallVector<Value, 4> new_results;
+    next_idx = 0;
+    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
+      if (removed_operand[op_idx]) {
+        new_cond_block.eraseArgument(next_idx);
+        new_body_block.eraseArgument(next_idx);
+        new_yield.eraseOperand(next_idx);
+        new_results.push_back(nullptr);
+      } else {
+        new_results.push_back(new_while_op.getResult(next_idx++));
+      }
+    }
+
+    rewriter.replaceOp(while_op, new_results);
+    return success();
+  }
+};
+
+}  // anonymous namespace
+
+void WhileRegionOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<WhileRegionEliminatePassThrough>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// XdivyOp
+//===----------------------------------------------------------------------===//
+
+void XdivyOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<XdivyWithSqrtDivisor>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc.inc"
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h
new file mode 100644
index 00000000000..761c06a475c
--- /dev/null
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +namespace mlir { +namespace TF { + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h.inc" + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_N_Z_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc new file mode 100644 index 00000000000..e87cc494a4a --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Identifier.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace mlir { +namespace TF { + +namespace { +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc" +#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.cc.inc" + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h new file mode 100644 index 00000000000..8586515edee --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h" + +namespace mlir { +namespace TF { + +#define GET_OP_FWD_DEFINES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_all_ops.h.inc" +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h.inc" + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_REMAINING_OPS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index edfc7feefd5..6883d0358ec 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -113,7 +113,8 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) { //===----------------------------------------------------------------------===// TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) - : Dialect(/*name=*/"tf_saved_model", context) { + : Dialect(/*name=*/"tf_saved_model", context, + TypeID::get()) { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" @@ -337,6 +338,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) { if (auto attr = func.getArgAttrOfType( i, "tf_saved_model.bound_input")) { if (!unique_bound_inputs.insert(attr.getValue()).second) { + if (module.getAttr("tf_saved_model.under_construction")) continue; return func.emitError() << "duplicate 'tf_saved_model.bound_input' binding"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index f488171d1e1..fc8e6f40f65 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ 
b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -65,11 +66,73 @@ class OperandsSameAsResultsTypeOrRef } }; +// Verifies that op has the same operand and result element types (or type +// itself, if scalar) after resolving reference types (i.e., after converting +// reference types to their corresponding TensorFlow or standard types). +template +class SameOperandsAndResultElementTypeResolveRef + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + Type element_type; + if (op->getNumResults() > 0) { + element_type = + mlir::TF::GetElementTypeOrSelfResolveRef(op->getResult(0).getType()); + } else if (op->getNumOperands() > 0) { + element_type = + mlir::TF::GetElementTypeOrSelfResolveRef(op->getOperand(0).getType()); + } else { + // Nothing to check. + return success(); + } + // Verify that all result element types are compatible to `element_type`. + for (const auto& result_type : op->getResultTypes()) { + if (mlir::TF::GetElementTypeOrSelfResolveRef(result_type) != + element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + // Verify that all operand element types are compatible to `element_type`. + for (const auto& operand_type : op->getOperandTypes()) { + if (mlir::TF::GetElementTypeOrSelfResolveRef(operand_type) != + element_type) { + return op->emitOpError( + "requires compatible element types for all operands and results"); + } + } + return success(); + } +}; + // Layout agnostic operations do not depend on the operands data layout (data // format), as and example all element wise operations are layout agnostic. template class LayoutAgnostic : public TraitBase {}; +// Trait to indicate operations that cannot be duplicated as they might carry +// certain state around within their implementations. +template +class CannotDuplicate : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + if (MemoryEffectOpInterface::hasNoEffect(op)) + return op->emitError( + "operations with no side effects cannot have CannotDuplicate trait"); + return success(); + } +}; + +// Coefficient-wise binary operation with implicit broadcasting support, for +// example tf.Sub operation. +template +class CwiseBinary : public TraitBase {}; + +// Coefficient-wise unary operation, for example tf.Sqrt operation. +template +class CwiseUnary : public TraitBase {}; + } // namespace TF } // namespace OpTrait } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index f352bc0eb47..125f6bb31df 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project namespace mlir { @@ -166,6 +167,17 @@ static inline Type GetDefaultTypeOf(TensorFlowRefType type) { return type.RemoveRef(); } +// Returns the element type if `type` is a `ShapedType` and the type itself +// otherwise, converting `TensorFlowRef` type to corresponding `TensorFlow` or +// standard type if necessary. +static inline Type GetElementTypeOrSelfResolveRef(Type type) { + Type element_type = mlir::getElementTypeOrSelf(type); + if (auto ref_type = element_type.dyn_cast()) { + element_type = ref_type.RemoveRef(); + } + return element_type; +} + #define HANDLE_TF_TYPE(tftype, enumerant, name) \ class tftype##Type : public detail::TensorFlowTypeImpl { \ public: \ diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 8597740a4ae..595bdce5be4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -143,6 +143,56 @@ func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) return %1 : tensor<2x2xi32> } +// CHECK-LABEL: testConcatCwiseUnary +func @testConcatCwiseUnary(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + + // CHECK: %[[CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %arg2) + // CHECK: %[[LOG1P:.*]] = "tf.Log1p"(%[[CONCAT]]) + // CHECK: return %[[LOG1P]] + %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor + %1 = "tf.Log1p"(%arg1) : (tensor) -> tensor + %2 = "tf.ConcatV2"(%0, %1, %arg2) : (tensor, tensor, tensor) -> tensor + + return %2 : tensor +} + +// CHECK-LABEL: testConcatCwiseBinaryOnInnerDim +func @testConcatCwiseBinaryOnInnerDim(%arg0: tensor, + %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + + // CHECK: %[[LHS_AXIS:.*]] = "tf.Const"() {value = dense<1> : tensor} + // CHECK: %[[RHS_AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} + + // CHECK: %[[LHS_CONCAT:.*]] = "tf.ConcatV2"(%arg0, %arg1, %[[LHS_AXIS]]) + // CHECK: %[[RHS_CONCAT:.*]] = "tf.ConcatV2"(%arg2, %arg3, %[[RHS_AXIS]]) + + // CHECK: %[[MUL:.*]] = "tf.Mul"(%[[LHS_CONCAT]], %[[RHS_CONCAT]]) + // CHECK-SAME: (tensor, tensor<2xf32>) -> tensor + // CHECK: return %[[MUL]] + + %0 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %1 = "tf.Mul"(%arg0, %arg2) : (tensor, tensor) -> tensor + %2 = "tf.Mul"(%arg1, %arg3) : (tensor, tensor) -> tensor + %3 = "tf.ConcatV2"(%1, %2, %0) : (tensor, tensor, tensor) -> tensor + + return %3 : tensor +} + +// CHECK-LABEL: testConcatCwiseBinaryInvalidInnerDim +func @testConcatCwiseBinaryInvalidInnerDim(%arg0: tensor, + %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + // Each individual binary operation has an implicit broadcast that will be + // lost if we would reorder them with the concat. 
+ + // CHECK: "tf.ConcatV2"(%1, %2, %0) + %0 = "tf.Const"() { value = dense<1> : tensor } : () -> tensor + %1 = "tf.Mul"(%arg0, %arg2) : (tensor, tensor) -> tensor + %2 = "tf.Mul"(%arg1, %arg3) : (tensor, tensor) -> tensor + %3 = "tf.ConcatV2"(%1, %2, %0) : (tensor, tensor, tensor) -> tensor + + return %3 : tensor +} + // CHECK-LABEL: testLogOfSoftmax func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -377,6 +427,86 @@ func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> { // CHECK: return %1 : tensor<2x8xi32> } +// CHECK-LABEL: testReshapeToSelfShape +func @testReshapeToSelfShape(%arg0: tensor) -> tensor { + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<2xi32> + %1 = "tf.Reshape"(%arg0, %0) : (tensor, tensor<2xi32>) -> tensor + + // CHECK: return %arg0 : tensor + return %1: tensor +} + +// CHECK-LABEL: func @testReshapeNoOp +func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> { + %0 = "tf.Reshape"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32> + + // CHECK: return %arg0 + return %0 : tensor<2x4xf32> +} + +// CHECK-LABEL: func @testPackShapeComputation +func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { + // Test dimensions sizes. + %d1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %d2 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor + + // Slice bounds. + %0 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + + // Fold pack operation if it computes the input tensor shape: + // + // %shape = tf.Shape(%arg) // [? x ...] + // %dim0 = tf.StridedSlice(%shape, 0, 1, 1) // get unknown dim0 value + // %pack = tf.Pack(dim0, ...) { axis = 0 } // [? x ...] + // + // Where `...` are some statically known dimensions. In this case %pack can be + // replace with a %shape. This is a common pattern in models with a dynamic + // batch size. + + // Test Rank 2 + // CHECK: %[[SHAPE0:.*]] = "tf.Shape" + %3 = "tf.Shape"(%arg0) : (tensor) -> tensor<2xi32> + %4 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %5 = "tf.Pack"(%4, %d1) {axis = 0 : i64} : (tensor, tensor) -> tensor<2xi32> + %6 = "tf.Reshape"(%arg0, %5) : (tensor, tensor<2xi32>) -> tensor + + // Test Rank 3. 
+ // CHECK: %[[SHAPE1:.*]] = "tf.Shape" + %7 = "tf.Shape"(%arg1) : (tensor) -> tensor<3xi32> + %8 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %9 = "tf.Pack"(%8, %d1, %d2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + %10 = "tf.Reshape"(%arg1, %9) : (tensor, tensor<3xi32>) -> tensor + + // Packed dimensions have different order from the reshape operand: + // [?, 1, 2] vs [?, 2, 1] + %14 = "tf.StridedSlice"(%7, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %15 = "tf.Pack"(%14, %d2, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // CHECK: %[[PACK0:.*]] = "tf.Pack" + + // StridedSlice takes second dimension from the shape: + // begin = [1], end = [2], stride = [1] + %17 = "tf.StridedSlice"(%7, %1, %2, %1) {shrink_axis_mask = 1 : i64} : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %18 = "tf.Pack"(%17, %d1, %d2) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // CHECK: %[[PACK1:.*]] = "tf.Pack" + + // Packed dimensions have higher rank than the reshape operand: + // [?, 1] vs [?, 1, 1] + %20 = "tf.StridedSlice"(%3, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + %21 = "tf.Pack"(%20, %d1, %d1) {axis = 0 : i64} : (tensor, tensor, tensor) -> tensor<3xi32> + // CHECK: %[[PACK2:.*]] = "tf.Pack" + + // Make sure a dynamic ranked shape doesn't crash the "canonicalize" pass + %23 = "tf.Shape"(%arg2) : (tensor<*xf32>) -> tensor<*xi32> + %24 = "tf.StridedSlice"(%23, %0, %1, %1) {shrink_axis_mask = 1 : i64} : (tensor<*xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<*xi32> + %25 = "tf.Pack"(%24, %d1) {axis = 0 : i64} : (tensor<*xi32>, tensor) -> tensor<*xi32> + // CHECK: %[[PACK3:.*]] = "tf.Pack" + + // CHECK: return %[[SHAPE0]], %[[SHAPE1]], %[[PACK0]], %[[PACK1]], %[[PACK2]], %[[PACK3]] + return %5, %9, %15, %18, %21, %25 : tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32> +} + // CHECK-LABEL: testSelectScalarPred func @testSelectScalarPred(%arg0: tensor, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> @@ -512,6 +642,18 @@ func @testRealDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32 // CHECK: return %1 } +// CHECK-LABEL: testRealDivWithConstDivisor +func @testRealDivWithConstDivisor(%arg0: tensor<8x2xf32>) -> tensor<8x2xf32> { + %0 = "tf.Const"() {value = dense<[2.0, 4.0]> : tensor<2xf32>} : () -> tensor<2xf32> + %1 = "tf.RealDiv"(%arg0, %0) : (tensor<8x2xf32>, tensor<2xf32>) -> tensor<8x2xf32> + return %1: tensor<8x2xf32> + + // CHECK: %0 = "tf.Const" + // CHECK-SAME: value = dense<[5.000000e-01, 2.500000e-01] + // CHECK: %1 = "tf.Mul"(%arg0, %0) + // CHECK: return %1 +} + // CHECK-LABEL: testTruncateDivWithSqrtDivisor func @testTruncateDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -663,6 +805,27 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex>) { return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex> } +// CHECK-LABEL: foldIf +func @foldIf(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor) { + %0 = "tf.Const"() {value = dense : 
tensor} : () -> tensor + %1 = "tf.Const"() {value = dense : tensor} : () -> tensor + + // CHECK: %0 = "tf.PartitionedCall"(%arg0, %arg1) + // CHECK-SAME: device = "noodle" + // CHECK-SAME: f = @sub + %2 = "tf.If"(%0, %arg0, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], device = "noodle", is_stateless = true} : (tensor, tensor, tensor) -> tensor + // CHECK: %1 = "tf.StatefulPartitionedCall"(%0, %arg1) + // CHECK-SAME: _underscore_attr = "something" + // CHECK-SAME: f = @add + %3 = "tf.If"(%1, %2, %arg1) {then_branch = @add, else_branch = @sub, output_shapes = [#tf.shape<>], _underscore_attr = "something", is_stateless = false} : (tensor, tensor, tensor) -> tensor + + // CHECK: %2 = "tf.If" + %4 = "tf.If"(%arg2, %3, %arg1) {then_branch = @add, else_branch = @sub, is_stateless = false} : (tensor, tensor, tensor) -> tensor + + // CHECK: return %2 + return %4 : tensor +} + // CHECK-LABEL: foldCase func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { %2 = constant dense<1> : tensor @@ -872,3 +1035,36 @@ func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor, %ar // CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32> return %0#0 : tensor<*xf32> } + +// Check that output_shapes attribute is removed for tf.If +func @testIfThen(tensor<*xf32>) -> tensor<*xf32> +func @testIfElse(tensor<*xf32>) -> tensor<*xf32> +// CHECK-LABEL: func @testIfDropOutputShapes +func @testIfDropOutputShapes(tensor, tensor<2xf32>) -> tensor<2xf32> { +^bb0(%arg0: tensor, %arg1: tensor<2xf32>): + // CHECK: "tf.If" + // CHECK-NOT: output_shapes + %1 = "tf.If"(%arg0, %arg1) { + then_branch = @testIfThen, else_branch = @testIfElse, is_stateless = false, output_shapes = [#tf.shape<>] + } : (tensor, tensor<2xf32>) -> tensor<2xf32> + + return %1 : tensor<2xf32> +} + +// Check that output_shapes attribute is removed for tf.Whileß +func @testWhileCond(tensor<*xf32>) -> (tensor) +func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) +// CHECK-LABEL: func @testWhileDropOutputShapes +func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + // CHECK: "tf.While" + // CHECK-NOT: output_shapes + %1 = "tf.While"(%arg0) { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = false, + output_shapes = [#tf.shape<>] + } : (tensor<*xf32>) -> (tensor<*xf32>) + + return %1 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir index 7b8c998bcf1..b86815dbe57 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/constant-fold.mlir @@ -443,7 +443,7 @@ func @DontRemoveTrivialMul(%arg0: tensor<1x6x8x1xf32>) -> tensor<1x6x8x1xf32> { // CHECK: return %[[RESULT]] : tensor<1x6x8x1xf32> } -// Do not fold if total result size is large (>128 KB) and more than 2 times +// Do not fold if total result size is large (>256 KB) and more than 2 times // the size of operands. 
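// For example, under this policy a fold that would materialize a 300 KB
// constant from operands totalling 100 KB is skipped (the result is both over
// 256 KB and more than 2x the operands), while a 300 KB result from 200 KB of
// operands, or a 200 KB result from 10 KB of operands, would still be folded.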
// LINT.IfChange(folding-policy-test) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir index 130887555b0..e7430993755 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/einsum.mlir @@ -98,6 +98,31 @@ func @einsum_transposereduceddim(%arg0: tensor<2x5x7xf32>, %arg1: tensor<2x5x3x7 // CHECK: return %[[v3]] : tensor<2x5x3xf32> } +func @einsum_fourdreducelast(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x3x5x13xf32>) -> tensor<2x7x5x13xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "acbe,aecd->abcd"}: (tensor<2x5x7x3xf32>, tensor<2x3x5x13xf32>) -> tensor<2x7x5x13xf32> + return %0 : tensor<2x7x5x13xf32> + // CHECK-LABEL: einsum_fourdreducelast + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg1, %[[cst]]) : (tensor<2x3x5x13xf32>, tensor<4xi32>) -> tensor<2x5x3x13xf32> + // CHECK: %[[v1:.*]] = "tf.BatchMatMulV2"(%arg0, %[[v0]]) {adj_x = false, adj_y = false} : (tensor<2x5x7x3xf32>, tensor<2x5x3x13xf32>) -> tensor<2x5x7x13xf32> + // CHECK: %[[v2:.*]] = "tf.Transpose"(%[[v1]], %[[cst]]) : (tensor<2x5x7x13xf32>, tensor<4xi32>) -> tensor<2x7x5x13xf32> + // CHECK: return %[[v2]] : tensor<2x7x5x13xf32> +} + +func @einsum_fourdtransposeall(%arg0: tensor<2x5x7x3xf32>, %arg1: tensor<2x11x7x3xf32>) -> tensor<2x7x11x5xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "aecd,abcd->acbe"}: (tensor<2x5x7x3xf32>, tensor<2x11x7x3xf32>) -> tensor<2x7x11x5xf32> + return %0 : tensor<2x7x11x5xf32> + // CHECK-LABEL: einsum_fourdtransposeall + // CHECK: %[[cst:.*]] = constant dense<[0, 2, 1, 3]> : tensor<4xi32> + // CHECK: %[[cst_1:.*]] = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + // CHECK: %[[cst_2:.*]] = constant dense<[0, 1, 3, 2]> : tensor<4xi32> + // CHECK: %[[v0:.*]] = "tf.Transpose"(%arg0, %[[cst]]) : (tensor<2x5x7x3xf32>, tensor<4xi32>) -> tensor<2x7x5x3xf32> + // CHECK: %[[v1:.*]] = "tf.Transpose"(%arg1, %[[cst_1]]) : (tensor<2x11x7x3xf32>, tensor<4xi32>) -> tensor<2x7x3x11xf32> + // CHECK: %[[v2:.*]] = "tf.BatchMatMulV2"(%[[v0]], %[[v1]]) {adj_x = false, adj_y = false} : (tensor<2x7x5x3xf32>, tensor<2x7x3x11xf32>) -> tensor<2x7x5x11xf32> + // CHECK: %[[v3:.*]] = "tf.Transpose"(%[[v2]], %[[cst_2]]) : (tensor<2x7x5x11xf32>, tensor<4xi32>) -> tensor<2x7x11x5xf32> + // CHECK: return %[[v3]] : tensor<2x7x11x5xf32> +} + func @einsum_no_match(%arg0: tensor<4x5xf32>, %arg1: tensor<5xf32>) -> tensor<4xf32> { %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ij,j->i"}: (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir index f45f0a435c3..b7bdf505a85 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/executor_tpuv1_inline_tpu_island.mlir @@ -35,11 +35,11 @@ module { } // CHECK-NOT: _tpu_v1_compat_outlined module @_tpu_v1_compat_outlined { - func @_tpu_v1_compat_outlined_func0(%arg0: tensor) -> tensor { + func @_tpu_v1_compat_outlined_func0(%arg0: tensor) -> tensor attributes {sym_visibility = "nested"} { %0 = "tf.opA"(%arg0) : (tensor) -> 
tensor return %0 : tensor } - func @_tpu_v1_compat_outlined_func1(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func @_tpu_v1_compat_outlined_func1(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) attributes {sym_visibility = "nested"} { %0 = "tf.opA"(%arg0) : (tensor) -> tensor %1 = "tf.opA"(%0) : (tensor) -> tensor %2 = "tf.SomeOp"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir index 8c174a7cfaf..6724033d292 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_island_inlining/while_op.mlir @@ -12,7 +12,7 @@ module { return %0#0 : tensor } module @_tpu_v1_compat_outlined { - func @_tpu_v1_compat_outlined_func0(%arg0: tensor) -> (tensor, tensor, tensor, tensor) { + func @_tpu_v1_compat_outlined_func0(%arg0: tensor) -> (tensor, tensor, tensor, tensor) attributes {sym_visibility = "nested"} { "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1 : i64, topology = "topology"} : () -> () %0 = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor %1 = "tf.While"(%0) {body = @while_body_with_cluster_attr, cond = @while_cond_with_cluster_attr, is_stateless = false, name = "A", parallel_iterations = 10 : i64} : (tensor) -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir index a7e9b22d72b..c8c82c5c08f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir @@ -1,5 +1,6 @@ -// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s --dump-input=fail +// RUN: tf-opt %s -tf-functional-control-flow-to-regions -split-input-file | FileCheck %s +// Simple If // CHECK: func @testIf1Then{{.+}} // CHECK: func @testIf1Else{{.+}} func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> @@ -8,7 +9,8 @@ func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> // CHECK-LABEL: func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.If"(%arg0, %arg1) { - then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false, + _attr0 = 10, _attr1 = true, attr2 = "hello" } : (tensor, tensor<*xf32>) -> tensor<*xf32> // CHECK: "tf.IfRegion" @@ -16,12 +18,19 @@ func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "tf.Yield"([[Result0]]) // CHECK: [[Result1:%.*]] = call @testIf1Else // CHECK: "tf.Yield"([[Result1]]) + // CHECK: _attr0 = 10 + // CHECK-SAME: _attr1 = true + // CHECK-NOT: attr2 = + // CHECK-NOT: else_branch + // CHECK-SAME: is_stateless = false + // CHECK-NOT: then_branch + // CHECK-SAME: } return %0 : tensor<*xf32> } // ----- -// With mismatching input types +// If with mismatching input types // CHECK: func @testIf1Then{{.+}} // CHECK: func @testIf1Else{{.+}} @@ -46,7 +55,7 @@ func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // ----- -// No inputs, some outputs +// If with no inputs, some outputs // CHECK: func @testIf1Then{{.+}} // CHECK: func 
@testIf1Else{{.+}} func @testIf1Then() -> tensor<*xf32> @@ -68,7 +77,7 @@ func @testIfNoInputs(%arg0: tensor) -> tensor<2xf32> { // ----- -// No outputs, some inputs +// If with no outputs, some inputs // CHECK: func @testIf1Then{{.+}} // CHECK: func @testIf1Else{{.+}} func @testIf1Then(tensor<*xf32>) -> () @@ -91,7 +100,8 @@ func @testIfNoResult(%arg0: tensor, %arg1: tensor<2xf32>) -> () { } // ----- -// No outputs, No inputs + +// If with no outputs, No inputs // CHECK: func @testIf1Then{{.+}} // CHECK: func @testIf1Else{{.+}} func @testIf1Then() -> () @@ -111,3 +121,82 @@ func @testIfNoInputAndNoResult(%arg0: tensor) -> () { return } +// ----- + +// Simple While +func @testWhileCond(tensor<*xf32>) -> (tensor) +func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) + +// CHECK-LABEL: func @testWhileResult +func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + %1 = "tf.While"(%arg0) { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = true, + _attr0 = 10, _attr1 = true, attr2 = "hello" + } : (tensor<*xf32>) -> (tensor<*xf32>) + + // CHECK: [[Result0:%.*]] = "tf.WhileRegion" + // CHECK: [[Result1:%.*]] = call @testWhileCond + // CHECK: "tf.Yield"([[Result1]]) + // CHECK: [[Result2:%.*]] = call @testWhileBody + // CHECK: "tf.Yield"([[Result2]]) + // CHECK: _attr0 = 10 + // CHECK-SAME: _attr1 = true + // CHECK-NOT: attr2 = + // CHECK-NOT: cond = + // CHECK-NOT: body = + // CHECK-SAME: is_stateless = true + // CHECK: return [[Result0]] + return %1 : tensor<*xf32> +} + +// ----- + +// While with no inputs & outputs +func @testWhileCond() -> (tensor) +func @testWhileBody() -> () + +// CHECK-LABEL: func @testWhileResultNoIO +func @testWhileResultNoIO() -> () { + "tf.While"() { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = false + } : () -> () + + // CHECK: "tf.WhileRegion" + // CHECK: [[Result1:%.*]] = call @testWhileCond + // CHECK: "tf.Yield"([[Result1]]) + // CHECK: call @testWhileBody + // CHECK: "tf.Yield"() + return +} + +// ----- + +// While with type mismatch +func @testWhileCond(tensor<4xf32>) -> (tensor) +func @testWhileBody(tensor<4xf32>) -> (tensor<4xf32>) + +// CHECK-LABEL: func @testWhileResult +func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + %1 = "tf.While"(%arg0) { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = false + } : (tensor<*xf32>) -> (tensor<*xf32>) + + // CHECK: [[Result0:%.*]] = "tf.WhileRegion" + // CHECK: [[ResultCast0:%.*]] = "tf.Cast" + // CHECK: [[Result1:%.*]] = call @testWhileCond([[ResultCast0]]) + // CHECK: "tf.Yield"([[Result1]]) + // CHECK: [[ResultCast1:%.*]] = "tf.Cast" + // CHECK: [[Result2:%.*]] = call @testWhileBody([[ResultCast1]]) + // CHECK: "tf.Yield"([[Result2]]) + // CHECK: return [[Result0]] + return %1 : tensor<*xf32> +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt new file mode 100644 index 00000000000..1372ad71283 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/case_op.pbtxt @@ -0,0 +1,261 @@ +# RUN: tf-mlir-translate -graphdef-to-splatted-mlir %s -o - | FileCheck %s + +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "Const_1" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr 
{ + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "indexed_case" + op: "StatelessCase" + input: "Const_1" + input: "Const" + attr { + key: "Tin" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_INT32 + } + } + } + attr { + key: "_lower_using_switch_merge" + value { + b: true + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "branches" + value { + list { + func { + name: "indexed_case_branch0_4" + } + func { + name: "indexed_case_branch1_5" + } + } + } + } + attr { + key: "output_shapes" + value { + list { + shape { + } + } + } + } +} +node { + name: "indexed_case/Identity" + op: "Identity" + input: "indexed_case" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +library { + function { + signature { + name: "indexed_case_branch0_4" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: "add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } + function { + signature { + name: "indexed_case_branch1_5" + input_arg { + name: "add_const" + type: DT_INT32 + } + output_arg { + name: "add" + type: DT_INT32 + } + } + node_def { + name: "add/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } + experimental_debug_info { + original_node_names: "add/y" + } + } + node_def { + name: "add_0" + op: "AddV2" + input: "add_const" + input: "add/y:output:0" + attr { + key: "T" + value { + type: DT_INT32 + } + } + experimental_debug_info { + original_node_names: "add" + } + } + ret { + key: "add" + value: "add_0:z:0" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + } + } +} +versions { + producer: 486 + min_consumer: 12 +} + +# CHECK: tf.Case +# CHECK-SAME: is_stateless diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt index cf08d55b3cb..304429c8783 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-custom-operation.pbtxt @@ -54,5 +54,5 @@ versions { # the names are matching between the function definition and the uses / call # site (a numerical suffix may be appended). 
-# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, f = @foo0} +# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0} # CHECK: func @foo0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt index fa6f63e27a5..f954657765a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-call.pbtxt @@ -34,6 +34,12 @@ node { b: true } } + attr { + key: "_tpu_replicate" + value { + s: "cluster" + } + } } library { function { @@ -62,4 +68,4 @@ library { } # CHECK: func @main -# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, f = @test_func_name0} +# CHECK: "tf.LegacyCall"(%arg0) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @test_func_name0} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt index 8cf6d4ed5d5..326e7b1ecd4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-name-bug.pbtxt @@ -121,8 +121,8 @@ versions { # Verify that functions from the library are properly imported. # CHECK-LABEL: func @main() { -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110} -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111} # CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"} # CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt index eb358d52b26..7cb7ac7e008 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-function-resource-args.pbtxt @@ -88,7 +88,7 @@ library { # CHECK: tf_executor.graph # CHECK: "tf.VarHandleOp"() # CHECK: "tf.LegacyCall" -# CHECK-SAME: {_disable_call_shape_inference = true, f = @test_func_name0} +# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name0} # CHECK: tf_executor.fetch # CHECK: return # CHECK: func @test_func_name0 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt index 55a76b1b668..53e951473d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/graph-library.pbtxt @@ -54,10 +54,10 @@ versions { # Verify that functions from the library are properly imported. 
# CHECK-LABEL: func @main() { -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0} -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0} # CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"} -# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} +# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0} # CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir new file mode 100644 index 00000000000..d8903846158 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/guarantee-all-funcs-one-use.mlir @@ -0,0 +1,54 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-guarantee-all-funcs-one-use | FileCheck %s + +// ----- +// Basic test. +// CHECK-LABEL: func @f +func @f() { + // CHECK: call @g() : () -> () + // CHECK: call @[[NEWG:.+]]() : () -> () + call @g() : () -> () + call @g() : () -> () + return +} + +// CHECK: func @g() +// CHECK: func @[[NEWG]]() attributes {sym_visibility = "private"} +func @g() { + return +} + +// ----- +// Transitive callees. +// CHECK-LABEL: func @f +// 2 copies of @g +// CHECK-DAG: func @g{{.*}} +// CHECK-DAG: func @g{{.*}} +// 4 copies of @h +// CHECK-DAG: func @h{{.*}} +// CHECK-DAG: func @h{{.*}} +// CHECK-DAG: func @h{{.*}} +// CHECK-DAG: func @h{{.*}} +func @f() { + call @g() : () -> () + call @g() : () -> () + return +} + +func @g() { + call @h() : () -> () + call @h() : () -> () + return +} + +func @h() { + return +} + +// ----- +// Handle error case of infinite recursion. +// expected-error @+1 {{reached cloning limit}} +func @f() attributes {sym_visibility = "private"} { + call @f() : () -> () + call @f() : () -> () + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import.mlir b/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import.mlir new file mode 100644 index 00000000000..6a9581b0e44 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt -tf-init-text-file-to-import-test %s | FileCheck %s + +// Tests that the tf.InitializeTableFromTextFileV2 op are inlined. 
+ +func @init_all_tables() { + %cst = constant dense<"%FILE_PLACEHOLDER"> : tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor, tensor) -> () + return + // CHECK: [[CST:%.*]] = constant dense<["apple", "banana", "grape"]> : tensor<3x!tf.string> + // CHECK: [[CST_0:%.*]] = constant dense<[0, 1, 2]> : tensor<3xi64> + // CHECK: [[VAL:%.*]] = "tf.HashTableV2"() + // CHECK: "tf.LookupTableImportV2"([[VAL]], [[CST]], [[CST_0]]) +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import_invalid.mlir new file mode 100644 index 00000000000..05afe1cc27f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/init_text_file_to_import_invalid.mlir @@ -0,0 +1,53 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -tf-init-text-file-to-import %s | FileCheck %s + +// Tests that the given vocabulary file does not exist. + +func @init_all_tables() { + %cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + // expected-error @+1 {{'tf.InitializeTableFromTextFileV2' op failed to open vocabulary file (vocab_file_does_not_exist.txt): cannot open input file 'vocab_file_does_not_exist.txt': No such file or directory}} + "tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor, tensor) -> () + return +} + +// ----- + +// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since +// unsupported key_index, -1. + +func @init_all_tables() { + %cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -1 : i64, value_index = -1 : i64, vocab_size = -1 : i64} : (tensor, tensor) -> () + return + // CHECK: [[VAL:%.*]] = "tf.HashTableV2"() + // CHECK: tf.InitializeTableFromTextFileV2" +} + +// ----- + +// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since +// unsupported value_index, 0. + +func @init_all_tables() { + %cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = 0 : i64, vocab_size = -1 : i64} : (tensor, tensor) -> () + return + // CHECK: [[VAL:%.*]] = "tf.HashTableV2"() + // CHECK: tf.InitializeTableFromTextFileV2" +} + +// ----- + +// Tests that the tf.InitializeTableFromTextFileV2 op is not converted since +// unsupported vocab_size, 1. 
+ +func @init_all_tables() { + %cst = constant dense<"vocab_file_does_not_exist.txt"> : tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf.string, shared_name = "hash_table_/tmp/vocab.txt_-2_-1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %cst) {delimiter = " ", device = "", key_index = -2 : i64, value_index = -1 : i64, vocab_size = 1 : i64} : (tensor, tensor) -> () + return + // CHECK: [[VAL:%.*]] = "tf.HashTableV2"() + // CHECK: tf.InitializeTableFromTextFileV2" +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir index 5f4bffcc7c2..7e583d0425a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/inlining.mlir @@ -2,7 +2,7 @@ // Test that simple TF operations can be inlined. -func @inline_simple_callee() -> tensor<2xi32> { +func @inline_simple_callee() -> tensor<2xi32> attributes {sym_visibility = "private"} { %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32> return %cst : tensor<2xi32> } @@ -18,7 +18,7 @@ func @inline_simple() -> tensor<2xi32> { // Check that TF call operations can be inlined, even when the shape of the // argument or result is different than the called function. -func @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> { +func @inline_shape_cast_callee(%arg : tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"} { return %arg : tensor<*xi32> } @@ -34,7 +34,12 @@ func @inline_shape_cast(%arg: tensor<2xi32>) -> tensor<2xi32> { // Check that functions can be inlined into islands. -func @inline_into_island_multi_block_callee() -> tensor<2xi32> { +func @inline_simple_callee1() -> tensor<2xi32> attributes {sym_visibility = "private"} { + %cst = "tf.Const"() { value = dense<2> : tensor<2xi32> } : () -> tensor<2xi32> + return %cst : tensor<2xi32> +} + +func @inline_into_island_multi_block_callee() -> tensor<2xi32> attributes {sym_visibility = "private"} { br ^bb1 ^bb1: @@ -48,7 +53,7 @@ func @inline_into_island() -> (tensor<2xi32>, tensor<2xi32>) { %1:3 = tf_executor.island { // Single block regions may be inlined. // CHECK: %[[CST:.*]] = "tf.Const" - %result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee} : () -> tensor<2xi32> + %result = "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @inline_simple_callee1} : () -> tensor<2xi32> // Multi block regions may not. 
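    // (A tf_executor.island body is a single-block region, so a callee whose
    // body spans multiple blocks cannot be inlined into it; the call remains.)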
// CHECK-NEXT: %[[CALL:.*]] = "tf.StatefulPartitionedCall" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 3215055a249..e11474c0755 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -test-tf-lower-tf | FileCheck %s +// RUN: tf-opt %s -test-tf-lower-tf | FILECHECK_OPTS="" FileCheck %s // CHECK-LABEL: invert_permutation func @invert_permutation(%arg0: tensor<5xi32>) -> tensor<5xi32> { @@ -353,8 +353,16 @@ func @ZerosLike_variant(%arg0: tensor>>) -> tensor>> } -// CHECK-LABEL: func @addN -func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: func @addN_2 +func @addN_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // return %[[SUM0]] + %0 = "tf.AddN"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addN_3 +func @addN_3(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%[[SUM0]], %arg2) // return %[[SUM1]] @@ -362,6 +370,27 @@ func @addN(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> return %0 : tensor<*xf32> } +// CHECK-LABEL: func @addN_4 +func @addN_4(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3) + // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]]) + // return %[[SUM2]] + %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// CHECK-LABEL: func @addN_5 +func @addN_5(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>, %arg3: tensor<*xf32>, %arg4: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[SUM0:.*]] = "tf.AddV2"(%arg0, %arg1) + // CHECK: %[[SUM1:.*]] = "tf.AddV2"(%arg2, %arg3) + // CHECK: %[[SUM2:.*]] = "tf.AddV2"(%[[SUM0]], %[[SUM1]]) + // CHECK: %[[SUM3:.*]] = "tf.AddV2"(%[[SUM2]], %arg4) + // return %[[SUM3]] + %0 = "tf.AddN"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @addN_variant func @addN_variant(%arg0: tensor>>, %arg1: tensor>>, %arg2: tensor>>) -> tensor>> { // CHECK: tf.AddN @@ -371,9 +400,7 @@ func @addN_variant(%arg0: tensor>>, %arg1: tensor) -> tensor<2x2xf32> { - // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> - // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32> - // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2xf32>, tensor<2xf32>, tensor) -> tensor<2x2xf32> // CHECK: return %[[RESULT]] @@ -411,9 +438,7 @@ func 
@DynamicStitch_uint8(%arg0: tensor<2x2xui8>) -> tensor<2x2xui8> { // CHECK-LABEL: func @DynamicStitch_scalar_item func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<-1> : tensor<1xi64>} : () -> tensor<1xi64> - // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2xf32>, tensor<1xi64>) -> tensor<2xf32> - // CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor, tensor) + // CHECK-DAG: %[[ITEMS]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2xf32>) -> (tensor, tensor) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK-DAG: %[[RESULT]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor, tensor, tensor) -> tensor<2xf32> // CHECK: return %[[RESULT]] @@ -425,9 +450,7 @@ func @DynamicStitch_scalar_item(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-LABEL: func @DynamicStitch_matrix_item func @DynamicStitch_matrix_item(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x2xf32> { - // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2, 2]> : tensor<3xi64>} : () -> tensor<3xi64> - // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> - // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) + // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2x2xf32>) -> (tensor<2x2xf32>, tensor<2x2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[ITEMS]]#0, %[[AXIS]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x2x2xf32> // CHECK: return %[[RESULT]] @@ -446,9 +469,7 @@ func @DynamicStitch_dynamic(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tenso // CHECK-LABEL: func @DynamicStitch_duplicates func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { - // CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() {value = dense<[-1, 2]> : tensor<2xi64>} : () -> tensor<2xi64> - // CHECK-DAG: %[[INP:.*]] = "tf.Reshape"(%arg0, %[[SHAPE]]) : (tensor<2x2xf32>, tensor<2xi64>) -> tensor<2x2xf32> - // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%[[INP]]) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) + // CHECK-DAG: %[[ITEMS:.*]]:2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<2x2xf32>) -> (tensor<2xf32>, tensor<2xf32>) // CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() {value = dense<0> : tensor} : () -> tensor // CHECK-DAG: %[[RESULT:.*]] = "tf.ConcatV2"(%[[ITEMS]]#1, %[[AXIS]]) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> // CHECK: return %[[RESULT]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir new file mode 100644 index 00000000000..2d86889e35b --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -0,0 +1,257 @@ +// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s + +// CHECK-LABEL: func @unsupported_op +func @unsupported_op() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%1) : (tensor) -> 
tensor + tf_device.return %2 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @tf2xla_fallback_op +func @tf2xla_fallback_op() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Sinh" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + %4 = "tf.Sinh"(%2) : (tensor) -> tensor + tf_device.return %4 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ignore_embedding_ops +func @ignore_embedding_ops() -> () { + "tf_device.cluster"() ( { + // CHECK: "tf.RecvTPUEmbeddingActivations" + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.SendTPUEmbeddingGradients" + // CHECK-NOT: _xla_outside_compilation + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + "tf.SendTPUEmbeddingGradients"(%2#0, %2#1) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> () + tf_device.return + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> () + return +} + +// CHECK-LABEL: func @op_string_result +func @op_string_result() -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Const" + // CHECK-SAME: _xla_outside_compilation + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Const"() {value = dense<"x"> : tensor} : () -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %3 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @op_string_operand +func @op_string_operand(%arg0: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.StringToNumber"(%arg0) {out_type = f32} : (tensor) -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + tf_device.return %3 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @op_string_operand_string_result +func @op_string_operand_string_result(%arg0: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.Identity" + // CHECK-SAME: _xla_outside_compilation + // CHECK-SAME: tf.string + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.Identity"(%arg0) : (tensor) -> tensor + %3 = "tf.Identity"(%1) : (tensor) -> tensor + 
tf_device.return %3 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that a tf.IfRegion op with a captured string operand is marked for outside compilation. + +// CHECK-LABEL: func @if_region_captured_string +func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK: "tf.StringToNumber" + // CHECK: _xla_outside_compilation = "auto", is_stateless = true + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%4) : (tensor) -> () + }) {is_stateless = true} : (tensor) -> (tensor) + %5 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that ops with string results/operands inside a tf.IfRegion branch are marked for outside compilation. + +// CHECK-LABEL: func @if_region_string_op +func @if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor} + // CHECK-NEXT: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + %4 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor + %5 = "tf.StringToNumber"(%4) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%5) : (tensor) -> () + // CHECK: {is_stateless + }) {is_stateless = true} : (tensor) -> (tensor) + %6 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %6: tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that ops with string results/operands inside a nested tf.IfRegion branch are marked for outside compilation. 
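A minimal sketch of the marking itself, assuming only what the CHECK lines in these tests assert (the string element type and SSA name here are illustrative, not taken from the patch): the pass leaves compilable ops untouched and attaches an `_xla_outside_compilation = "auto"` attribute to ops it cannot compile, for example a string constant inside a `tf_device.cluster`:

// Before -tf-mark-ops-for-outside-compilation (inside a tf_device.cluster):
%s = "tf.Const"() {value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>
// After the pass: same op, now carrying the outside-compilation marker.
%s = "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor<!tf.string>} : () -> tensor<!tf.string>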
+ +// CHECK-LABEL: func @nested_if_region_string_op +func @nested_if_region_string_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.IfRegion" + // CHECK-NOT: _xla_outside_compilation + %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %2 = "tf.IfRegion"(%arg0) ( { + %3 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%3) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {value = dense : tensor} + // CHECK-NOT: _xla_outside_compilation + %4 = "tf.Const"() {value = dense : tensor} : () -> tensor + %5 = "tf.IfRegion"(%4)({ + // CHECK: "tf.Const"() {_xla_outside_compilation = "auto", value = dense<"1.0"> : tensor} + // CHECK-NEXT: "tf.StringToNumber" + // CHECK-SAME: _xla_outside_compilation + %6 = "tf.Const"() {value = dense<"1.0"> : tensor} : () -> tensor + %7 = "tf.StringToNumber"(%6) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%7) : (tensor) -> () + }, { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + %8 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + "tf.Yield"(%8) : (tensor) -> () + // CHECK: {is_stateless + }){is_stateless = true} : (tensor) -> (tensor) + "tf.Yield"(%5) : (tensor) -> () + // CHECK: {is_stateless + }) {is_stateless = true} : (tensor) -> (tensor) + %9 = "tf.Identity"(%2) : (tensor) -> tensor + tf_device.return %9: tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that a tf.WhileRegion op with a captured string operand is marked for outside compilation. + +// CHECK-LABEL: func @while_region_captured_string +func @while_region_captured_string(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.WhileRegion" + // CHECK: "tf.StringToNumber" + // CHECK: _xla_outside_compilation = "auto", is_stateless = true + %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %2:2 = "tf.WhileRegion"(%1, %arg0) ( { + ^bb0(%carg0: tensor, %carg1: tensor): + %limit = constant dense<5> : tensor + %cond = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + %3 = "tf.StringToNumber"(%arg1) {out_type = f32} : (tensor) -> tensor + "tf.Yield"(%3, %sub) : (tensor, tensor) -> () + }) {is_stateless = true} : (tensor, tensor) -> (tensor, tensor) + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %5 = "tf.Identity"(%2#0) : (tensor) -> (tensor) + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} + +// Test that an unsupported op within a tf.WhileRegion is marked for outside compilation. 
+ +// CHECK-LABEL: func @while_region_unsupported_op +func @while_region_unsupported_op(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "tf_device.cluster"() ( { + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + // CHECK-NOT: _xla_outside_compilation + // CHECK: "tf.WhileRegion" + %1 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + %2:2 = "tf.WhileRegion"(%1, %arg0) ( { + ^bb0(%carg0: tensor, %carg1: tensor): + %limit = constant dense<5> : tensor + %cond = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: tensor): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + // CHECK: "tf.UnsupportedOp" + // CHECK-SAME: _xla_outside_compilation + %3 = "tf.UnsupportedOp"() {value = dense<1> : tensor} : () -> tensor + // CHECK: "tf.Const"() {value = dense<1.000000e+00> : tensor} + %4 = "tf.Const"() {value = dense<1.0> : tensor} : () -> tensor + "tf.Yield"(%4, %sub) : (tensor, tensor) -> () + // CHECK: {is_stateless = true + }) {is_stateless = true} : (tensor, tensor) -> (tensor, tensor) + // CHECK: "tf.Identity" + // CHECK-NOT: _xla_outside_compilation + %5 = "tf.Identity"(%2#0) : (tensor) -> (tensor) + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir new file mode 100644 index 00000000000..2f2ee6f1286 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/case.mlir @@ -0,0 +1,38 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 486 : i32}} { + func @main() { + tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %outputs_2, %control_3 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = true, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("stateless_case") + %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2) {device = ""} : (tensor<*xi32>) -> tensor<*xi32> + %outputs_6, %control_7 = tf_executor.island wraps "tf.Case"(%outputs_0, %outputs) {Tin = [i32], Tout = [i32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], branches = [@indexed_case_branch0_40, @indexed_case_branch1_50], device = "", is_stateless = false, output_shapes = [#tf.shape<>]} : (tensor, tensor) -> tensor<*xi32> loc("regular_case") + tf_executor.fetch + } + return + } + + func @indexed_case_branch0_40(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } + + func 
@indexed_case_branch1_50(%arg0: tensor) -> tensor<*xi32> attributes {sym_visibility = "private"} { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.AddV2"(%arg0, %outputs) {device = ""} : (tensor, tensor) -> tensor<*xi32> + tf_executor.fetch %outputs_0 : tensor<*xi32> + } + return %0 : tensor<*xi32> + } +} + +// CHECK: name: "stateless_case" +// CHECK-NEXT: "StatelessCase" +// CHECK: name: "regular_case" +// CHECK-NEXT: "Case" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir new file mode 100644 index 00000000000..fadb62c44b8 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_attr.mlir @@ -0,0 +1,40 @@ +// RUN: tf-mlir-translate -mlir-to-graphdef %s | tf-mlir-translate -graphdef-to-mlir | tf-mlir-translate -mlir-to-graphdef | FileCheck %s + +// Tests #tf.func attributes are exported as AttrValue.NameAttrList attributes +// with its attr field populated with nested attributes. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 458 : i32}} { + func @main() { + tf_executor.graph { + %control = tf_executor.island wraps "tf.NoOp"() {_f = #tf.func<@callee, {attr2 = true, attr3 = 8.0 : f32}>} : () -> () + tf_executor.fetch + } + return + } + func @callee() { + tf_executor.graph { + tf_executor.fetch + } + return + } +} + +// CHECK: op: "NoOp" +// CHECK-NEXT: attr +// CHECK-NEXT: key: "_f" +// CHECK-NEXT: value +// CHECK-NEXT: func +// CHECK-NEXT: name: [[FUNC_NAME:".*"]] +// CHECK-NEXT: attr +// CHECK-NEXT: key: "attr2" +// CHECK-NEXT: value +// CHECK-NEXT: b: true +// CHECK: attr +// CHECK-NEXT: key: "attr3" +// CHECK-NEXT: value +// CHECK-NEXT: f: 8 + +// CHECK: library +// CHECK-NEXT: function +// CHECK-NEXT: signature +// CHECK-NEXT: name: [[FUNC_NAME]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir index d9ad36f2ce6..b6933459382 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-if-ops.mlir @@ -1,13 +1,13 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s -func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { +func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor<4xf32>, %arg3: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { %0:2 = tf_executor.graph { %outputs_2, %control_3 = tf_executor.island wraps "tf.Less"(%arg0, %arg1) : (tensor, tensor) -> tensor - %outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor loc("StatefulIf") - %outputs_6, %control_7 = tf_executor.island wraps "tf.If"(%outputs_2, %arg0, %arg1) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor) -> tensor loc("StatelessIf") - tf_executor.fetch %outputs_4, %outputs_6 : tensor, tensor + %outputs_4, %control_5 = tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = false, then_branch = @cond_true} : (tensor, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatefulIf") + %outputs_6, %control_7 = 
tf_executor.island wraps "tf.If"(%outputs_2, %arg2, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("StatelessIf") + tf_executor.fetch %outputs_4, %outputs_6 : tensor<4xf32>, tensor<4xf32> } - return %0#0, %0#1 : tensor, tensor + return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> } func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { @@ -34,8 +34,32 @@ func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { // CHECK-NOT: name: // CHECK: op: "If" // CHECK-NOT: is_stateless +// CHECK: attr { +// CHECK: key: "output_shapes" +// CHECK: value { +// CHECK: list { +// CHECK: shape { +// CHECK: dim { +// CHECK: size: 4 +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } // CHECK: name: "StatelessIf" // CHECK-NOT: name: // CHECK: op: "StatelessIf" // CHECK-NOT: is_stateless +// CHECK: attr { +// CHECK: key: "output_shapes" +// CHECK: value { +// CHECK: list { +// CHECK: shape { +// CHECK: dim { +// CHECK: size: 4 +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir index 9f14a144d9d..c7a4630d985 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/functional-while-ops.mlir @@ -1,12 +1,12 @@ // RUN: tf-mlir-translate -mlir-to-graphdef %s -o - | FileCheck %s -func @main(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { +func @main(%arg0: tensor, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { %0:2 = tf_executor.graph { - %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor, tensor) -> (tensor, tensor) loc("StatefulWhile") - %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor, tensor) -> (tensor, tensor) loc("StatelessWhile") - tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor, tensor + %outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor, tensor<5xf32>) -> (tensor, tensor<5xf32>) loc("StatefulWhile") + %outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor, tensor<5xf32>) -> (tensor, tensor<5xf32>) loc("StatelessWhile") + tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32> } - return %0#0, %0#1 : tensor, tensor + return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32> } func @cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> tensor { @@ -36,8 +36,34 @@ func @body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) -> (tensor<*xi32>, tensor // CHECK-NOT: name: // CHECK: op: "While" // CHECK-NOT: is_stateless +// CHECK: attr { +// CHECK: key: "output_shapes" +// CHECK: value { +// CHECK: list { +// CHECK: shape { +// CHECK: dim { +// CHECK: size: 5 +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } + // CHECK: name: "StatelessWhile" // CHECK-NOT: name: // CHECK: op: "StatelessWhile" // CHECK-NOT: is_stateless +// CHECK: attr { +// CHECK: key: "output_shapes" +// CHECK: value { +// CHECK: list { +// CHECK: shape { +// CHECK: dim { +// CHECK: size: 5 +// CHECK: } +// CHECK: } +// CHECK: } 
+// CHECK: } +// CHECK: } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir index 5f92d789066..3e50aa18098 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/tf-legacy-call.mlir @@ -3,7 +3,7 @@ func @main() { tf_executor.graph { %outputs, %control = tf_executor.island wraps "tf.Const"() {device = "", dtype = "tfdtype$DT_INT32", name = "Constant", value = dense<0> : tensor} : () -> tensor - %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {f = @foo0} : (tensor) -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.LegacyCall"(%outputs) {_tpu_replicate = "cluster", device = "", f = @foo0} : (tensor) -> tensor tf_executor.fetch } return @@ -23,6 +23,12 @@ func @foo0(%arg0: tensor<*xi32>) -> tensor<*xi32> { // CHECK-NEXT: value { // CHECK-NEXT: list { // CHECK-NEXT: shape { +// CHECK: attr { +// CHECK-NEXT: key: "_tpu_replicate" +// CHECK-NEXT: value { +// CHECK-NEXT: s: "cluster" +// CHECK-NEXT: } +// CHECK-NEXT: } // CHECK: library { // CHECK-NEXT: function { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir index 31ca7b28fe7..52dc06cd393 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallel_execute_to_islands.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-parallel-execute-to-islands | FileCheck %s +// RUN: tf-opt %s -tf-parallel-execute-to-islands | FILECHECK_OPTS="" FileCheck %s // CHECK-LABEL: func @check_regions_to_islands func @check_regions_to_islands() { @@ -17,11 +17,9 @@ func @check_regions_to_islands() { return } -// CHECK: %[[ISLAND_INPUT_CTL:[a-z_0-9]*]] = tf_executor.island { -// CHECK-NEXT: tf_executor.yield -// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) { +// CHECK: %[[ISLAND_1_CTL:[a-z_0-9]*]] = tf_executor.island { // CHECK: tf_executor.yield -// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island(%[[ISLAND_INPUT_CTL]]) { +// CHECK: %[[ISLAND_2_CTL:[a-z_0-9]*]] = tf_executor.island { // CHECK: tf_executor.yield // CHECK: %{{.*}} = tf_executor.island(%[[ISLAND_1_CTL]], %[[ISLAND_2_CTL]]) { // CHECK-NEXT: tf_executor.yield @@ -192,3 +190,37 @@ func @check_output_barrier_correctly_forwards_outputs(%arg0 : tensor) -> ten // CHECK: tf_executor.yield %[[OP_C_OUTPUT]] : tensor // CHECK: %[[OUTPUT_SINK_OUTPUT:[a-z_0-9]*]]:2, %[[OUTPUT_SINK_CTL:[a-z_0-9]*]] = tf_executor.island { // CHECK-NEXT: tf_executor.yield %[[ISLAND_1_OUTPUT]], %[[ISLAND_2_OUTPUT]] : tensor, tensor + +// CHECK-LABEL: func @check_parallel_execute_using_args +// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor) +func @check_parallel_execute_using_args(%arg0 : tensor) { + tf_executor.graph { + %1:2 = tf_executor.island { + %2 = "tf.opA"(%arg0) : (tensor) -> tensor + tf_executor.yield %2 : tensor + } + %2:2 = tf_executor.island { + %3 = "tf.opB"(%arg0) : (tensor) -> tensor + tf_executor.yield %3 : tensor + } + tf_executor.island() { + "tf_device.parallel_execute"() ({ + %4 = "tf.opC"(%arg0, %1#0) : (tensor, tensor) -> tensor + tf_device.return %4 : tensor + }, + { + %5 = "tf.opD"(%arg0, %2#0) : (tensor, tensor) -> tensor + tf_device.return %5 : tensor + }) {} : () -> (tensor, tensor) + tf_executor.yield + } + tf_executor.fetch + 
} + return +} + +// Verify that args are directly accessed in newly created island without alias +// through entry barrier. + +// CHECK: "tf.opC"(%[[ARG_0]] +// CHECK: "tf.opD"(%[[ARG_0]] diff --git a/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir new file mode 100644 index 00000000000..e1cfaba5dcc --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/parallelize_embedding_params_ops_pass.mlir @@ -0,0 +1,96 @@ +// RUN: tf-opt %s -tf-parallize-embedding-params-ops -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK-LABEL: func @two_shards +func @two_shards(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*x!tf.resource>>, %arg3: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.ReadVariableOp" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %3 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// Verifies that resource reads shared across two shards are kept outside the +// parallel_execute op. 
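A rough sketch of the shape those CHECK lines describe, assuming resource element types of `tensor<8xf32>` as in the surrounding tests (the SSA names are illustrative): the shared reads are hoisted ahead of the generated `tf_device.parallel_execute`, and each shard's load lands in its own region:

%r0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
%r1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
"tf_device.parallel_execute"() ({
  // shard 0
  "tf.LoadTPUEmbeddingAdagradParameters"(%r0, %r1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
  tf_device.return
}, {
  // shard 1
  "tf.LoadTPUEmbeddingAdagradParameters"(%r0, %r1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> ()
  tf_device.return
}) {} : () -> ()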
+ +// CHECK-LABEL: func @shared_reads +func @shared_reads(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // CHECK: "tf.ReadVariableOp" + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + // CHECK: "tf.ReadVariableOp" + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + + // CHECK: "tf_device.parallel_execute" + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + // CHECK: "tf.LoadTPUEmbeddingAdagradParameters" + // CHECK: tf_device.return + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// Verifies that if the resource variables are used in ops other than read +// variable op whose semantics are not known then the function is kept +// unchanged. + +// CHECK-LABEL: func @update_var +func @update_var(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>, %arg2: tensor<*x!tf.resource>>) { + tf_executor.graph { + // CHECK-NOT: tf_device.parallel_execute + %control = tf_executor.island { + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 2 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + + %2 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %zeros = "tf.Const"() {value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + "tf.AssignVariableOp"(%arg2, %zeros) : (tensor<*x!tf.resource>>, tensor<8xf32>) -> () + %3 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + "tf.LoadTPUEmbeddingAdagradParameters"(%2, %3) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards = 2 : i64, shard_id = 1 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} + +// ----- + +func @invalid_shard_range(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) { + tf_executor.graph { + %control = tf_executor.island { + // expected-error @-1 {{require continuous range of shards}} + %0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + %1 = "tf.ReadVariableOp"(%arg1) {device = ""} : (tensor<*x!tf.resource>>) -> tensor<8xf32> + + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:0/device:CPU:0", num_shards = 3 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + "tf.LoadTPUEmbeddingAdagradParameters"(%0, %1) {config = "", device = "/job:worker/replica:0/task:1/device:CPU:0", num_shards 
= 3 : i64, shard_id = 3 : i64, table_id = -1 : i64, table_name = "param_table"} : (tensor<8xf32>, tensor<8xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %control : !tf_executor.control + } + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir index 40cfc03b8e6..3e6d4f37bac 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/promote_resources_to_args.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-promote-resources-to-args | FILECHECK_OPTS="" FileCheck %s // One resource, one read. The initial value of the resource is read. // CHECK-LABEL: func @main(%arg0: tensor, %arg1: tensor {tf.resource_name = "x"}) -> tensor<2xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir index 2b8f47a407e..7d36e6f4319 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/readonly_references_to_resources.mlir @@ -12,6 +12,18 @@ func @f() { // ----- +// Test case: Basic converting. '_class' attribute is at IdentityOp. + +func @f() { + // CHECK: "tf.VarHandleOp" + // CHECK: "tf.ReadVariableOp" + %val0 = "tf.VariableV2"() {container = "", device = "", shape = #tf.shape<96>, shared_name = ""} : () -> tensor<96x!tf.f32ref> + %val1 = "tf.Identity"(%val0) {_class = ["loc:@v"]} : (tensor<96x!tf.f32ref>) -> tensor<96xf32> + return +} + +// ----- + // Test case: Two ReadVariable ops. 
func @f() { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index 5ea863852ad..e9d4e441a10 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -1,19 +1,23 @@ -// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file -//| FileCheck %s --dump-input=fail +// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file | FileCheck %s +// Simple IfRegion // CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> // CHECK-NEXT: "tf.Neg" // CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> // CHECK-NEXT: "tf.Abs" func @testSimple(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { - // CHECK: "tf.If"{{.+}}else_branch = @tf.IfRegion_else{{.+}}then_branch = @tf.IfRegion_then + // CHECK: "tf.If" + // CHECK-SAME: _attr0 = false + // CHECK-NOT: attr1 + // CHECK-SAME: else_branch = @tf.IfRegion_else + // CHECK-SAME: then_branch = @tf.IfRegion_then %0 = "tf.IfRegion"(%arg0) ({ %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%1) : (tensor<*xf32>) -> () }, { %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> "tf.Yield"(%2) : (tensor<*xf32>) -> () - }) { is_stateless = true } : (tensor) -> tensor<*xf32> + }) {is_stateless = true, _attr0 = false, attr1 = "hello"} : (tensor) -> tensor<*xf32> return %0 : tensor<*xf32> } @@ -42,7 +46,7 @@ func @testIfCondition(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> // ----- -// Constant sinking +// Constant sinking for IfRegion // CHECK: func @tf.IfRegion_else() -> tensor<2xf32> // CHECK-NEXT: constant dense<1.0 @@ -105,7 +109,7 @@ func @testNested(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { // ----- -// Match existing function->Region pattern (simple) +// Match existing function->Region pattern (simple) for IfRegion func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { @@ -122,7 +126,7 @@ func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { // ----- -// Match existing function->Region pattern (with casts) +// Match existing function->Region pattern (with casts) for IfRegion func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> @@ -142,7 +146,29 @@ func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { // ----- -// No inputs, some outputs +// Match existing function->Region pattern (with multiple casts) for IfRegion + +func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> +func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> +func @testIf2Result(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: "tf.If"({{.+}}) {else_branch = @testIf1Else, {{.+}} then_branch = @testIf1Then} + %0 = "tf.IfRegion"(%arg0) ( { + %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor) -> tensor<*xf32> + %3 = call @testIf1Then(%2) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%3) : (tensor<*xf32>) -> () + }, { + %1 = "tf.Cast"(%arg1) {Truncate = false} : (tensor<2xf32>) -> tensor + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor) -> tensor<*xf32> + %3 = call @testIf1Else(%2) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%3) : (tensor<*xf32>) -> () 
+ }) {is_stateless = false} : (tensor) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +// No inputs, some outputs for IfRegion // CHECK: func @tf.IfRegion_else() -> tensor<2xf32> // CHECK-NEXT: constant dense<1.000000e+00> // CHECK-NEXT: "tf.Neg" @@ -165,7 +191,7 @@ func @testSimple(%arg0: tensor) -> tensor<2xf32> { // ----- -// No outputs, some inputs +// No outputs, some inputs for IfRegion // // CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) // CHECK-NEXT: "tf.Neg" @@ -186,3 +212,383 @@ func @testNoOutputs(%arg0: tensor, %arg1: tensor<*xf32>) -> () { return } +// ----- + +// Simple WhileRegion +// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"} +// CHECK: "tf.Add" +// CHECK: constant dense<1> +// CHECK: "tf.Sub" +// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"} +// CHECK: constant dense<0> +// CHECK: "tf.NotEqual" +// CHECK-LABEL: testValidWhileRegion +func @testValidWhileRegion(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) + // CHECK-SAME: _attr0 = false + // CHECK-NOT: attr1 + // CHECK-SAME: body = @tf.WhileRegion_body + // CHECK-SAME: cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false, _attr0 = false, attr1 = "hello"} : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with type mismatch +// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"} +// CHECK: "tf.Add" +// CHECK: constant dense<1> +// CHECK: "tf.Sub" +// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"} +// CHECK: constant dense<0> +// CHECK: "tf.NotEqual" +// CHECK-LABEL: testWhileRegionTypeMismatch +func @testWhileRegionTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + // condition, check if count has reached 0 + ^bb0(%carg0: tensor<4xf32>, %carg1: tensor): + %zero = constant dense<0> : tensor + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<4xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<4xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with constant sinking +// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"} +// CHECK: constant dense<1> +// CHECK: "tf.Add" +// CHECK: "tf.Sub" +// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"} +// CHECK: 
constant dense<0> +// CHECK: "tf.NotEqual" +// CHECK-LABEL: testWhileRegionConstantSink +func @testWhileRegionConstantSink(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + %zero = constant dense<0> : tensor + %one = constant dense<1> : tensor + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<4xf32>, %carg1: tensor): + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<4xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<4xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with implicitly captured extern value in cond +// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK: "tf.Add" +// CHECK: constant dense<1> +// CHECK: "tf.Sub" +// CHECK: return %{{.+}}, %{{.+}}, %arg2 : tensor<*xf32>, tensor, tensor +// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK: "tf.NotEqual"(%arg1, %arg2) +// CHECK-LABEL: testWhileRegionExternInCond +func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %cst = constant dense<4> : tensor + %limit = "tf.Add"(%arg2, %cst) : (tensor, tensor) -> tensor + // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}}) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %ne = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg1, %one) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with implicitly captured extern value in body +// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK: %0 = "tf.Add"(%arg0, %arg0) +// CHECK: %1 = "tf.Sub"(%arg1, %arg2) +// CHECK: return %0, %1, %arg2 + +// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK: constant dense<0> +// CHECK: "tf.NotEqual" + +// CHECK-LABEL: testWhileRegionExternInBody +func @testWhileRegionExternInBody(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %zero = constant dense<0> : tensor + %cst = constant dense<4> : tensor + %stride = "tf.Add"(%arg2, %cst) : (tensor, tensor) -> tensor + // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}}) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %ne = "tf.NotEqual"(%carg1, %zero) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + 
%add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %sub = "tf.Sub"(%barg1, %stride) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with implicitly captured extern value in cond and body +// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) +// CHECK: return %{{.+}}, %{{.+}}, %arg2, %arg3 +// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) +// CHECK-LABEL: testWhileRegionExternInBodyAndCond +func @testWhileRegionExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %cst = constant dense<4> : tensor + %stride = "tf.Add"(%arg2, %cst) : (tensor, tensor) -> tensor + %cst1 = constant dense<44> : tensor + %limit = "tf.Add"(%arg2, %cst1) : (tensor, tensor) -> tensor + // CHECK: [[Result:%.*]]:4 = "tf.While"(%arg0, %arg1, %{{.+}}, %{{.+}}) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %ne = "tf.NotEqual"(%carg1, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %sub = "tf.Sub"(%barg1, %stride) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// WhileRegion with same value implicitly captured in cond and body +// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK: return %{{.+}}, %{{.+}}, %arg2 +// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor, %arg2: tensor) +// CHECK-LABEL: testWhileRegionSameExternInBodyAndCond +func @testWhileRegionSameExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %cst = constant dense<4> : tensor + %stride = "tf.Add"(%arg2, %cst) : (tensor, tensor) -> tensor + // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}}) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %ne = "tf.NotEqual"(%carg1, %stride) : (tensor, tensor) -> tensor + "tf.Yield"(%ne) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %sub = "tf.Sub"(%barg1, %stride) : (tensor, tensor) -> tensor + "tf.Yield"(%add, %sub) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Simple trivially transformable while +// CHECK: func @while_cond +// CHECK: func @while_body +// CHECK-LABEL: testWhileRegionTrivial +func @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor +func @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor) -> (tensor<*xf32>, tensor) +func @testWhileRegionTrivial(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: 
[[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond = call @while_cond(%carg0, %carg1) : (tensor<*xf32>, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy:2 = call @while_body(%barg0, %barg1) : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Trivially transformable with casts +// CHECK: func @while_cond +// CHECK: func @while_body +// CHECK-LABEL: testWhileRegionTrivialCasts +func @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor) -> tensor +func @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor) -> (tensor<4xf32>, tensor) +func @testWhileRegionTrivialCasts(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond_cast = "tf.Cast"(%carg0) : (tensor<*xf32>) -> tensor<4xf32> + %cond = call @while_cond(%cond_cast, %carg1) : (tensor<4xf32>, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy_cast = "tf.Cast"(%barg0) : (tensor<*xf32>) -> tensor<4xf32> + %bdy:2 = call @while_body(%bdy_cast, %barg1) : (tensor<4xf32>, tensor) -> (tensor<4xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<4xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Trivially transformable with multiple casts +// CHECK: func @while_cond +// CHECK: func @while_body +// CHECK-LABEL: testWhileRegionTrivialMultipleCasts +func @while_cond(%arg0 : tensor<4xf32>, %arg1 : tensor) -> tensor +func @while_body(%arg0 : tensor<4xf32>, %arg1 : tensor) -> (tensor<4xf32>, tensor) +func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond_cast0 = "tf.Cast"(%carg0) : (tensor<*xf32>) -> tensor + %cond_cast1 = "tf.Cast"(%cond_cast0) : (tensor) -> tensor<4xf32> + %cond = call @while_cond(%cond_cast1, %carg1) : (tensor<4xf32>, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy_cast0 = "tf.Cast"(%barg0) : (tensor<*xf32>) -> tensor + %bdy_cast1 = "tf.Cast"(%bdy_cast0) : (tensor) -> tensor<4xf32> + %bdy:2 = call @while_body(%bdy_cast1, %barg1) : (tensor<4xf32>, tensor) -> (tensor<4xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<4xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Almost trivially transformable with extern values +// CHECK: func @tf.WhileRegion_body +// CHECK: call @while_body +// CHECK: @tf.WhileRegion_cond +// CHECK: call @while_cond +// CHECK-LABEL: testWhileRegionExtern +func @while_cond(%arg0 : tensor<*xf32>, %arg1 : 
tensor) -> tensor +func @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor, %arg2 : tensor<*xf32>) -> (tensor<*xf32>, tensor) +func @testWhileRegionExtern(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + %ext = "tf.Neg"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: [[Result:%.*]]:3 = "tf.While"(%arg0, %arg1, %{{.+}}) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond = call @while_cond(%carg0, %carg1) : (tensor<*xf32>, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy:2 = call @while_body(%barg0, %barg1, %ext) : (tensor<*xf32>, tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Almost trivially transformable, mismatching block arguments +// CHECK: func @tf.WhileRegion_body +// CHECK: call @while_body +// CHECK: @tf.WhileRegion_cond +// CHECK: call @while_cond +// CHECK-LABEL: testWhileRegionBlockArgMismatch +func @while_cond(%arg0 : tensor, %arg1 : tensor<*xf32>) -> tensor +func @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor) -> (tensor<*xf32>, tensor) +func @testWhileRegionBlockArgMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @tf.WhileRegion_body, cond = @tf.WhileRegion_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond = call @while_cond(%carg1, %carg0) : (tensor, tensor<*xf32>) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy:2 = call @while_body(%barg0, %barg1) : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir index 9931a45f995..487234ce958 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/replicate_to_island.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tf-replicate-to-island | FileCheck %s +// RUN: tf-opt -split-input-file %s -tf-replicate-to-island | FileCheck %s // Tests per replica island has same control operands as island holding // replicate. @@ -223,3 +223,219 @@ func @replica_id_attr_added(%arg0: tensor, %arg1: tensor // CHECK: "tf.A" // CHECK-NOT: _xla_replica_id // CHECK: tf_executor.fetch + + +// Tests device ordinals are added to `tf._XlaSendFromHost`/`tf._XlaRecvAtHost` +// based on the first TPU core device id. 
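A worked example of the expectation that follows, as a gloss rather than part of the patch (the `tensor<f32>` result type is illustrative): replica 0 of this test is placed on "/job:worker/replica:0/task:0/device:TPU:1", so its cloned island rewrites the placeholder ordinal to 1, while replica 1 maps to TPU:2 and gets ordinal 2:

// Input op inside tf_device.replicate (device_ordinal is a placeholder):
%0 = "tf._XlaRecvAtHost"(%arg1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>
// Replica 0 copy after -tf-replicate-to-island (TPU:1 -> ordinal 1):
%0 = "tf._XlaRecvAtHost"(%arg1) {_xla_has_host_transfer = true, device_ordinal = 1 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor<f32>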
+// CHECK-LABEL: func @device_ordinals +func @device_ordinals(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + %0 = "tf._XlaRecvAtHost"(%arg1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg1) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: tf_executor.island +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf.NoOp" +// CHECK: tf_executor.island +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf.NoOp" + +// ----- + +// Tests functions with replica variant ops reachable from a replicate region +// is cloned and remapped. + +// CHECK-LABEL: func @call_with_replicate_variant_ops +func @call_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALL_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALL_REPLICA_1:@[a-z0-9_]+]] + +func @send_recv(%arg0: tensor<2x!tf.string>) { + %0 = "tf._XlaRecvAtHost"(%arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + return +} + +// CHECK: func [[CALL_REPLICA_0]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[CALL_REPLICA_1]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests transitive functions with replica variant ops reachable from a +// replicate region is cloned and remapped. 
+ +// CHECK-LABEL: func @call_with_replicate_variant_ops +func @call_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALLEE_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[CALLEE_REPLICA_1:@[a-z0-9_]+]] + +func @callee(%arg0: tensor<2x!tf.string>) { + "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + return +} + +func @send_recv(%arg0: tensor<2x!tf.string>) { + %0 = "tf._XlaRecvAtHost"(%arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + "tf._XlaSendFromHost"(%0, %arg0) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + "tf.NoOp"() : () -> () + return +} + +// CHECK: func [[CALLEE_REPLICA_0]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[TRANSITIVE_CALLEE_REPLICA_0:@[a-z0-9_]+]] + +// CHECK: func [[TRANSITIVE_CALLEE_REPLICA_0]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[CALLEE_REPLICA_1]] +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = [[TRANSITIVE_CALLEE_REPLICA_1:@[a-z0-9_]+]] + +// CHECK: func [[TRANSITIVE_CALLEE_REPLICA_1]] +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests functional control flow functions with replica variant ops reachable +// from a replicate region is cloned and remapped. Only the branches reachable +// with replica variant ops are cloned. 
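Put differently (an editorial gloss of the CHECK lines below; the clone symbol name and the element types are hypothetical, since the pass generates its own names and the exact types are elided here): `@cond_true` contains `tf._XlaSendFromHost`/`tf._XlaRecvAtHost`, so each replica's `tf.If` is retargeted at a per-replica clone of that branch with rewritten device ordinals, while `@cond_false`, which has no replica-variant ops, stays shared:

// Replica 0 (sketch):
%0 = "tf.If"(%arg4, %arg5, %arg6, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true_replica_0} : (tensor<i1>, tensor<f32>, tensor<f32>, tensor<2x!tf.string>) -> tensor<f32>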
+ +// CHECK-LABEL: func @control_flow_with_replicate_variant_ops +func @control_flow_with_replicate_variant_ops(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg4: tensor, [%arg1, %arg1] as %arg5: tensor, [%arg2, %arg2] as %arg6: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + %0 = "tf.If"(%arg4, %arg5, %arg6, %arg3) {else_branch = @cond_false, is_stateless = true, then_branch = @cond_true} : (tensor, tensor, tensor, tensor<2x!tf.string>) -> tensor + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.If" +// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_0:@[a-z0-9_]+]] +// CHECK: "tf.If" +// CHECK-SAME: else_branch = @cond_false +// CHECK-SAME: then_branch = [[COND_TRUE_REPLICA_1:@[a-z0-9_]+]] + +func @cond_false(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x!tf.string>) -> tensor { + return %arg0 : tensor +} + +// CHECK-NOT: func @cond_false.+( + +func @cond_true(%arg0: tensor, %arg1: tensor, %arg2: tensor<2x!tf.string>) -> tensor { + "tf._XlaSendFromHost"(%arg1, %arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_recv_0"} : (tensor, tensor<2x!tf.string>) -> () + %0 = "tf._XlaRecvAtHost"(%arg2) {_xla_has_host_transfer = true, device_ordinal = 0 : i64, key = "host_compute_channel_send_0"} : (tensor<2x!tf.string>) -> tensor + return %0 : tensor +} + +// CHECK: func [[COND_TRUE_REPLICA_0]] +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 1 +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 1 + +// CHECK: func [[COND_TRUE_REPLICA_1]] +// CHECK: "tf._XlaSendFromHost" +// CHECK-SAME: device_ordinal = 2 +// CHECK: "tf._XlaRecvAtHost" +// CHECK-SAME: device_ordinal = 2 + +// ----- + +// Tests function with no replica variant ops reachable from a replicate region +// is not cloned. + +// CHECK-LABEL: func @no_replicate_variant_ops +func @no_replicate_variant_ops(%arg0: tensor, %arg1: tensor<2x!tf.string>) { + tf_executor.graph { + tf_executor.island { + tf_device.replicate([%arg0, %arg0] as %arg2: tensor) {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2"]}} { + "tf.StatefulPartitionedCall"(%arg1) {config = "", config_proto = "", executor_type = "", f = @send_recv} : (tensor<2x!tf.string>) -> () + tf_device.return + } + tf_executor.yield + } + tf_executor.fetch + } + return +} + +// CHECK: "tf.StatefulPartitionedCall" +// CHECK-SAME: f = @send_recv + +func @send_recv(%arg0: tensor<2x!tf.string>) { + "tf.NoOp"() : () -> () + return +} + +// CHECK-NOT: @send_recv.+( diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir new file mode 100644 index 00000000000..87da399b726 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir @@ -0,0 +1,234 @@ +// RUN: tf-opt -split-input-file -tf-test-resource-alias-analysis -verify-diagnostics %s | FileCheck %s + +// Test 2 resources that do not alias. 
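+// Each resource handle below gets its own ID, and the reported alias set for +// each handle is a singleton.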
+ +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @non_aliasing_reads_writes +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @non_aliasing_reads_writes( + %arg0: !tf_res, + %arg1: !tf_res, + %arg2: tensor<32xf32>) -> (tensor<32xf32>) { + %graph = tf_executor.graph { + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + %read0 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg0, %arg2) : (!tf_res, tensor<32xf32>) -> () + %read1 = "tf.ReadVariableOp"(%arg1) : (!tf_res) -> tensor<32xf32> + // expected-remark@below {{Result #0, ID 0 : 0}} + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read2 = "tf.ReadVariableOp"(%var_handle) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%arg1, %read0) : (!tf_res, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%arg0, %read2) : (!tf_res, tensor<32xf32>) -> () + %read3 = "tf.ReadVariableOp"(%arg0) : (!tf_res) -> tensor<32xf32> + tf_executor.yield %read3 : tensor<32xf32> + } + tf_executor.fetch %island#0 : tensor<32xf32> + } + return %graph : tensor<32xf32> +} + +// ----- +// Tests aliasing of the two resource handles that refer to the same variable. + +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @aliasing_reads_writes +func @aliasing_reads_writes(%arg0: tensor<32xf32>) -> () { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %vh1 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 0, 1, 2}} + %vh1_id:2 = "tf.IdentityN"(%vh1, %arg0) : (!tf_res, tensor<32xf32>) -> (!tf_res, tensor<32xf32>) + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh1_id#0, %arg0) : (!tf_res, tensor<32xf32>) -> () + %read1 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + %read2 = "tf.ReadVariableOp"(%vh1) : (!tf_res) -> tensor<32xf32> + "tf.AssignVariableOp"(%vh0, %read2) : (!tf_res, tensor<32xf32>) -> () + "tf.AssignVariableOp"(%vh1_id#0, %read1) : (!tf_res, tensor<32xf32>) -> () + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// ----- +// Test an unknown op that has a resource result is marked unknown + +!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @unknown_resource_op +func @unknown_resource_op(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %0 = "tf.UnknownVarHandleOp"() : () -> !tf_res +} + +// ----- +// Test aliasing through IfOp + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @if_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +func @if_op_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5}} + // expected-remark@below {{Result #2, ID 3 : 0, 1, 2, 3, 5}} + %if:3 = "tf.If"(%read0, %arg1, %vh0) { + then_branch = 
@if_then, else_branch = @if_else, is_stateless = true + } : (tensor<32xf32>, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 2 : 0, 1, 2}} +// expected-remark@below {{Region #0, Arg #1, ID 3 : 0, 3}} +func @if_then(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %u0, %id0, %id0 : !tf_res, !tf_res, !tf_res +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 2}} +func @if_else(%arg0: !tf_res, %arg1: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + return %id0, %id0, %arg1 : !tf_res, !tf_res, !tf_res +} + +// ----- +// Test aliasing through WhileOp +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @while_op_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 4 : 1, 4}} +// expected-remark@below {{Region #0, Arg #1, ID 5 : 1, 2, 3, 5}} +// expected-remark@below {{Region #0, Arg #2, ID 6 : 1, 2, 3, 6}} +func @while_op_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 1, 2, 3, 5, 6}} + // expected-remark@below {{Result #2, ID 3 : 1, 2, 3, 5, 6}} + %w:3 = "tf.While"(%arg0, %arg1, %arg2) { + body = @while_body, cond = @while_cond, is_stateless = false + } : (!tf_res, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + +// CHECK-LABEL: func @while_body +// Return 0 : new unknown resource +// Return 1 : arg2 +// Return 2 : arg1 +// expected-remark@below {{Region #0, Arg #0, ID 1 : 0, 1}} +// expected-remark@below {{Region #0, Arg #1, ID 2 : 0, 2}} +// expected-remark@below {{Region #0, Arg #2, ID 3 : 0, 3}} +func @while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) -> (!tf_res, !tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + return %u0, %arg2, %arg1 : !tf_res, !tf_res, !tf_res +} + +// CHECK-LABEL: func @while_cond +// expected-remark@below {{Region #0, Arg #0, ID 0 : 0}} +// expected-remark@below {{Region #0, Arg #1, ID 1 : 1}} +// expected-remark@below {{Region #0, Arg #2, ID 2 : 2}} +func @while_cond(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) -> tensor { + %0 = constant dense : tensor + return %0 : tensor +} + +// ----- +// Test alias propagation through calls. 
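+// The call result that simply forwards the argument aliases the caller's +// handles, while the handle created inside the callee is reported as Unknown +// at the call site.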
+!tf_res = type tensor<*x!tf.resource>> +// CHECK-LABEL: func @aliasing_through_calls +func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2, 3}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : 0, 1, 2, 3}} + %vh1 = "tf.Identity"(%vh0) : (!tf_res) -> (!tf_res) + // expected-remark@below {{Result #0, ID 2 : Unknown}} + // expected-remark@below {{Result #1, ID 3 : 0, 1, 2, 3}} + %c:2 = call @passthru(%vh1) : (!tf_res) -> (!tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vx = "tf.VarHandleOp"() {container = "cf", shared_name = "vx"} : () -> !tf_res + return %vx, %arg0 : !tf_res, !tf_res +} + +// ----- +// Test aliasing through IfRegion + +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @if_region_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 7 : 1, 4, 6, 7}} +// expected-remark@below {{Region #0, Arg #1, ID 8 : 1, 2, 4, 5, 6, 8}} +func @if_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 3, 4, 5}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + // expected-remark@below {{Result #0, ID 4 : Unknown}} + // expected-remark@below {{Result #1, ID 5 : 0, 1, 2, 3, 4, 5, 6, 8}} + // expected-remark@below {{Result #2, ID 6 : 1, 2, 4, 5, 6, 7, 8}} + %if:3 = "tf.IfRegion"(%read0) ({ + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + // expected-remark@below {{Result #0, ID 2 : 1, 2, 4, 5, 6, 8}} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + "tf.Yield"(%u0, %id0, %id0) : (!tf_res, !tf_res, !tf_res) -> () + }, { + // expected-remark@below {{Result #0, ID 3 : 0, 1, 3, 4, 5}} + %id0 = "tf.Identity"(%vh0) : (!tf_res) -> !tf_res + "tf.Yield"(%id0, %id0, %arg0) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = true} : (tensor<32xf32>) -> (!tf_res, !tf_res, !tf_res) + return +} + +// ----- +// Test aliasing through WhileRegion +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @while_region_aliasing +// expected-remark@below {{Region #0, Arg #0, ID 11 : 1, 8, 11}} +// expected-remark@below {{Region #0, Arg #1, ID 12 : 1, 8, 9, 10, 12}} +// expected-remark@below {{Region #0, Arg #2, ID 13 : 1, 8, 9, 10, 13}} +func @while_region_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 8}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 8 : Unknown}} + // expected-remark@below {{Result #1, ID 9 : 1, 8, 9, 10, 12, 13}} + // expected-remark@below {{Result #2, ID 10 : 1, 8, 9, 10, 12, 13}} + // expected-remark@below {{Region #0, Arg #0, ID 2 : 1, 2, 8}} + // expected-remark@below {{Region #0, Arg #1, ID 3 : 1, 3, 8}} + // expected-remark@below {{Region #0, Arg #2, ID 4 : 1, 4, 8}} + // expected-remark@below {{Region #1, Arg #0, ID 5 : 1, 5, 8}} + // expected-remark@below {{Region #1, Arg #1, ID 6 : 1, 6, 8}} + // expected-remark@below {{Region #1, Arg #2, ID 7 : 1, 7, 8}} + %w:3 = "tf.WhileRegion"(%arg0, %arg1, %arg2) ({ + ^bb0(%carg0: !tf_res, %carg1: !tf_res, %carg2: !tf_res): + %0 = constant dense : tensor + "tf.Yield"(%0) : 
(tensor) -> () + },{ + ^bb0(%barg0: !tf_res, %barg1: !tf_res, %barg2: !tf_res): + // expected-remark@below {{Result #0, ID 1 : Unknown}} + %u0 = "tf._UnknownSideEffectingOp_"() : () -> !tf_res + "tf.Yield"(%u0, %barg2, %barg1) : (!tf_res, !tf_res, !tf_res) -> () + }) {is_stateless = false} : (!tf_res, !tf_res, !tf_res) -> (!tf_res, !tf_res, !tf_res) + return +} + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir index a9e814c647e..a4a7c1dad2e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir @@ -56,7 +56,7 @@ func @propagate_if_op( "tf.If"(%arg1, %id0, %var_handle) { then_branch = @if_then, else_branch = @if_else, - output_shapes = [], is_stateless = false} + is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> () tf_executor.yield @@ -128,8 +128,7 @@ func @propagate_while_op( // CHECK-NEXT: "tf.While" "tf.While"(%arg1, %id0, %var_handle) { body = @while_body, - cond = @while_cond, - output_shapes = [], is_stateless = false} + cond = @while_cond, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, @@ -209,8 +208,7 @@ func @error_on_conflict_multiple_callers( : () -> tensor<*x!tf.resource>> "tf.If"(%arg1, %id0, %var_handle) { then_branch = @if_then_and_else, - else_branch = @if_then_and_else, - output_shapes = [], is_stateless = false} + else_branch = @if_then_and_else, is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> () "tf.If"(%arg1, %var_handle, %id0) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index 7c8e4382e2b..ac5c2df8f7e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-resource-op-lifting | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-resource-op-lifting | FILECHECK_OPTS="" FileCheck %s // Tests that resource load operations are hoisted. 
@@ -147,8 +147,7 @@ func @cluster_with_loop() -> () { "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]]:2 = "tf.While"(%[[COUNT]], %[[READ]]) %2:3 = "tf.While"(%0, %1, %unused) - {body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>]} + {body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]]#1 : tensor @@ -197,8 +196,7 @@ func @cluster_with_loop() -> () { "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor @@ -239,8 +237,7 @@ func @cluster_with_loop() -> () { "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) // CHECK: tf_device.return @@ -278,8 +275,7 @@ func @cluster_with_nested_loop() -> () { "tf_device.cluster"() ( { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[READ]]) %2:2 = "tf.While"(%0, %1) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK: tf_device.return %[[WHILE]] : tensor @@ -295,8 +291,7 @@ func @while_body(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) { // CHECK: %[[WHILE:.*]] = "tf.While"(%[[BARG0]]) %0:2 = "tf.While"(%arg0, %arg1) { - body = @while_body1, cond = @while_cond1, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>]} + body = @while_body1, cond = @while_cond1, device = "", is_stateless = false} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) // CHECK-NEXT: return %[[WHILE]] @@ -334,8 +329,7 @@ func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () @@ -359,8 +353,7 @@ func @cluster_with_loop() -> () { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () @@ -384,8 +377,7 @@ func @cluster_with_loop() -> 
() { %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> "tf_device.cluster"() ( { %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false, - output_shapes = [#tf.shape<>]} + body = @while_body, cond = @while_cond, device = "", is_stateless = false} : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) tf_device.return }) {cluster_attr = "cluster_attr"} : () -> () @@ -600,6 +592,35 @@ func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf. // ----- +// Tests that the pass reports error if output does not alias input. + +func @cluster_with_if(%arg0: tensor) -> tensor<4xf32> { + %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> + %1 = "tf.VarHandleOp"() {container = "c", shared_name = "v2"} : () -> tensor<*x!tf.resource>> + %2 = "tf_device.cluster"() ( { + // expected-error @+1 {{unsupported output: resource does not alias input}} + %3 = "tf.If"(%arg0, %0, %1) {then_branch = @if_then, else_branch = @if_else, + is_stateless = false} + : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>) + %4 = "tf.ReadVariableOp"(%3) : (tensor<*x!tf.resource>>) -> tensor<4xf32> + tf_device.return %4 : tensor<4xf32> + }) {cluster_attr = "cluster_attr"} : () -> tensor<4xf32> + return %2 : tensor<4xf32> +} +func @if_then(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>) { + %0 = "tf.foo"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + return %0 : tensor<*x!tf.resource>> +} +func @if_else(%arg0: tensor<*x!tf.resource>>, %arg1: tensor<*x!tf.resource>>) + -> (tensor<*x!tf.resource>>) { + %0 = "tf.bar"(%arg0) : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + return %0 : tensor<*x!tf.resource>> +} + +// ----- + // Tests that the pass lifts resources on two partitioned call ops sharing the // same callee. The lifting should clone the callee then modify the clone. 
diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 4193edf8cc6..4a5e3c8deaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -100,10 +100,11 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %1 : tensor } - // CHECK-LABEL: func @shape_from_if_to_branch_functions - func @shape_from_if_to_branch_functions(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> { - %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], _xla_propagate_compile_time_consts = true, device = "", else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> - return %0 : tensor<1x2x3xf32> + // CHECK-LABEL: func @shape_from_if_to_branch_functions_to_results + // CHECK-SAME: (%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + func @shape_from_if_to_branch_functions_to_results(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> { + %0 = "tf.If"(%arg0, %arg1) {Tcond = i1, Tin = ["tfdtype$DT_FLOAT"], Tout = ["tfdtype$DT_FLOAT"], else_branch = @if_else_branch, is_stateless = true, name = "if", then_branch = @if_then_branch} : (tensor, tensor<1x2x3xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> } // CHECK-LABEL: func @if_then_branch @@ -124,6 +125,27 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor<*xf32> } + // Verify shape propagation from function arg -> if region body -> if region output -> function return type + // CHECK-LABEL: shape_from_if_to_region_bodies_to_output + // CHECK-SAME: -> tensor<1x2x3xf32> + func @shape_from_if_to_region_bodies_to_output(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> { + %unshaped = "tf.Cast"(%arg1) : (tensor<1x2x3xf32>) -> tensor<*xf32> + %0 = "tf.IfRegion"(%arg0) ({ + // CHECK: "tf.Add"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> () + %1 = "tf.Add"(%unshaped, %unshaped) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + // CHECK: "tf.Sub"{{.+}}(tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + // CHECK: "tf.Yield"{{.+}}(tensor<1x2x3xf32>) -> () + %2 = "tf.Sub"(%unshaped, %unshaped) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + // CHECK: {is_stateless = true} : (tensor) -> tensor<1x2x3xf32> + }) {is_stateless = true} : (tensor) -> tensor<*xf32> + // CHECK: return {{.*}} : tensor<1x2x3xf32> + return %0 : tensor<*xf32> + } + // CHECK-LABEL: func @shape_from_while_to_cond_body_functions func @shape_from_while_to_cond_body_functions(%arg0: tensor<4xf32>, %arg1: tensor>>, %arg2: tensor>>) -> tensor<4xf32> { // CHECK: "tf.While" @@ -169,6 +191,33 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %1, %arg1, %arg2 : tensor<*xf32>, tensor<*x!tf.resource>, tensor>> } + // Verify shape propagation from function arg -> while region cond/body -> while region output -> function return type + // CHECK-LABEL: func @shape_from_while_operands_to_cond_body_to_while_results + // CHECK-SAME: -> tensor<1x2x3xf32> + func @shape_from_while_operands_to_cond_body_to_while_results(%arg0: tensor, %arg1: tensor<1x2x3xf32>) -> tensor<*xf32> { + %unshaped = "tf.Cast"(%arg1) : 
(tensor<1x2x3xf32>) -> tensor<*xf32> + // CHECK: "tf.WhileRegion" + %0:2 = "tf.WhileRegion"(%arg0, %unshaped) ({ + // CHECK: {{.*}}({{.+}}: tensor, {{.+}}: tensor<1x2x3xf32>): + ^bb0(%carg0: tensor, %carg1: tensor<*xf32>): + %limit = constant dense<5> : tensor + %cond = "tf.NotEqual"(%carg0, %limit) : (tensor, tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + // CHECK: {{.*}}({{.+}}: tensor, {{.+}}: tensor<1x2x3xf32>): + ^bb0(%barg0: tensor, %barg1: tensor<*xf32>): + %one = constant dense<1> : tensor + %sub = "tf.Sub"(%barg0, %one) : (tensor, tensor) -> tensor + // CHECK: "tf.Neg"({{.+}}) : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> + %neg = "tf.Neg"(%barg1) : (tensor<*xf32>) -> tensor<*xf32> + // CHECK: "tf.Yield"{{.+}}, {{.+}}) : (tensor, tensor<1x2x3xf32>) -> () + "tf.Yield"(%sub, %neg) : (tensor, tensor<*xf32>) -> () + // CHECK: {is_stateless = true} : (tensor, tensor<1x2x3xf32>) -> (tensor, tensor<1x2x3xf32>) + }) {is_stateless = true} : (tensor, tensor<*xf32>) -> (tensor, tensor<*xf32>) + // CHECK: return {{.+}}#1 : tensor<1x2x3xf32> + return %0#1 : tensor<*xf32> + } + // CHECK-LABEL: func @shape_from_case_to_branch_functions( // CHECK-SAME: %[[ARG_0:.*]]: tensor, // CHECK-SAME: %[[ARG_1:.*]]: tensor>> @@ -219,7 +268,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @reused_if_then_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-warning @+1 {{expected control flow function reused_if_then_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function @reused_if_then_branch to have exactly 1 use}} func @reused_if_then_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: return // CHECK-SAME: tensor<*xf32> @@ -228,7 +277,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-LABEL: func @reused_if_else_branch // CHECK-SAME: (%arg0: tensor<*xf32>) -> tensor<*xf32> - // expected-warning @+1 {{expected control flow function reused_if_else_branch to have exactly 1 use}} + // expected-warning @+1 {{expected control flow function @reused_if_else_branch to have exactly 1 use}} func @reused_if_else_branch(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: "tf.Identity"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Identity"(%arg0) : (tensor<*xf32>) -> (tensor<*xf32>) @@ -499,4 +548,16 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { %outputs_2 = "tf.TensorSliceDataset"(%outputs_0) {device = "", output_shapes = [#tf.shape<>]} : (tensor<*xf32>) -> tensor return } + + // Test resource result subtypes are propagated to call op results. 
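+  // The call result below is expected to be refined so that it carries the +  // same resource subtype as the callee's return type.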
+ // CHECK-LABEL: func @pcall_resource_result + func @pcall_resource_result(%arg0: tensor<*x!tf.resource>>) { + // CHECK: "tf.StatefulPartitionedCall" + // CHECK-SAME: (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_resource_result_func} : (tensor<*x!tf.resource>>) -> tensor<*x!tf.resource> + return + } + func @pcall_resource_result_func(%arg0: tensor<*x!tf.resource>>) -> tensor<*x!tf.resource>> { + return %arg0 : tensor<*x!tf.resource>> + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 44646690519..20a0e22c48e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -191,6 +191,24 @@ func @testMul(%arg0: tensor<2xui16>) -> (tensor<2xui16>) { // ----- +// Test error message for incompatible element types. +func @testIncompatibleElementTypes(%arg0: tensor<3x2xf32>, %arg1: tensor<3x2xf64>) -> (tensor<3x2xf32>) { + // expected-error @+1 {{'tf.Mul' op requires compatible element types for all operands and results}} + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<3x2xf32>, tensor<3x2xf64>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +// Test error message for incompatible element types. +func @testIncompatibleElementTypes(%arg0: tensor<3x2xf32>, %arg1: tensor<3x2xf32>) -> (tensor<3x2xf64>) { + // expected-error @+1 {{'tf.Mul' op requires compatible element types for all operands and results}} + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf64> + return %0 : tensor<3x2xf64> +} + +// ----- + // CHECK-LABEL: func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) func @testReshape(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<10000xf32>, %arg3: tensor<*xi32>) -> (tensor<100x100xf32>, tensor<*xf32>, tensor<10000xf32>, tensor<100x100xf32>, tensor<*xf32>, tensor<*xf32>) { %shape1 = constant dense<100> : tensor<2xi32> @@ -2026,6 +2044,71 @@ func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { // ----- +// Test tf.Transpose with partial unknown shape +// CHECK-LABEL: testTranspose +func @testTranspose(tensor<2x?xf32>) -> tensor { +^bb0(%arg0: tensor<2x?xf32>): + %cst = constant dense<[1, 0]> : tensor<2xi32> + %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x?xf32>, tensor<2xi32>) -> tensor + return %0 : tensor +} + +// ----- + +// Test tf.Transpose with different partial unknown shape +// CHECK-LABEL: testTranspose +func @testTranspose(tensor<2x?x?xf32>) -> tensor<3x?x2xf32> { +^bb0(%arg0: tensor<2x?x?xf32>): + %cst = constant dense<[2, 1, 0]> : tensor<3xi32> + %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x?x?xf32>, tensor<3xi32>) -> tensor<3x?x2xf32> + return %0 : tensor<3x?x2xf32> +} + +// ----- + +// Test tf.Transpose with invalid rank of perm +func @testTranspose(tensor<2x3xf32>, tensor<1x2xi32>) -> tensor<3x2xf32> { +^bb0(%arg0: tensor<2x3xf32>, %arg1: tensor<1x2xi32>): + // expected-error @+1 {{expected perm to be a 1-D Tensor, got perm of rank 2}} + %0 = "tf.Transpose"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x3xf32>, tensor<1x2xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +// Test tf.Transpose with invalid size of perm +func 
@testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { +^bb0(%arg0: tensor<2x3xf32>): + %cst = constant dense<[1, 0, 2]> : tensor<3xi32> + // expected-error @+1 {{expected perm to be a 1-D Tensor of size equal to the rank of x, got perm of size 3, and x of rank 2}} + %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x3xf32>, tensor<3xi32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} + +// ----- + +// Test tf.Transpose with invalid rank of y +func @testTranspose(tensor<2x3xf32>) -> tensor<3x2x1xf32> { +^bb0(%arg0: tensor<2x3xf32>): + %cst = constant dense<[1, 0]> : tensor<2xi32> + // expected-error @+1 {{x should be of the same rank with y, got x of rank 2, and y of rank 3}} + %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2x1xf32> + return %0 : tensor<3x2x1xf32> +} + +// ----- + +// Test tf.Transpose with invalid shape of y +func @testTranspose(tensor<2x3x4xf32>) -> tensor<3x2x4xf32> { +^bb0(%arg0: tensor<2x3x4xf32>): + %cst = constant dense<[2, 0, 1]> : tensor<3xi32> + // expected-error @+1 {{requires y.shape[0] (3) to be equal to x.shape[perm[2]] (4)}} + %0 = "tf.Transpose"(%arg0, %cst) {T = "tfdtype$DT_FLOAT", Tperm = "tfdtype$DT_INT32"} : (tensor<2x3x4xf32>, tensor<3xi32>) -> tensor<3x2x4xf32> + return %0 : tensor<3x2x4xf32> +} + +// ----- + // Test invalid tf.Less func @testLess(tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> { ^bb0(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>): diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir index 2f034f1bfae..0e9814de137 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_executor_ops_invalid.mlir @@ -232,6 +232,20 @@ func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { // ----- +// Check that an island body doesn't have any block arguments. +func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { + tf_executor.graph { + "tf_executor.island"() ({ + // expected-error@-1 {{expects body without any arguments}} + ^entry(%arg: tensor<2xi32>): + tf_executor.yield + }) : () -> (!tf_executor.control) + } + return +} + +// ----- + // Check that an island body can't be empty. func @invalid_island(%arg0: tensor<*xf32>, %ctl: !tf_executor.control) { tf_executor.graph { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py index 209ed3492e8..19e7a90c1e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/control_flow_upgrade_legacy_v1.py @@ -33,9 +33,10 @@ from tensorflow.python.ops import control_flow_ops def Test(): data = tf.constant([1, 2, 3, 4, 5, 6]) - zero = tf.convert_to_tensor(0) - one = tf.convert_to_tensor(1) - less_op = tf.less(zero, one) + # Create placeholders to prevent constant folding. 
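+  # (With constant inputs the tf.less predicate could be folded away, removing +  # the Switch/Merge control flow this test is meant to exercise.)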
+ x_op = tf.placeholder(dtype=tf.int32) + y_op = tf.placeholder(dtype=tf.int32) + less_op = tf.less(x_op, y_op) switch_op = control_flow_ops.switch(data, less_op) merge_op = control_flow_ops.merge(switch_op)[0] result = tf.transpose(merge_op) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py index 7e86953eb8f..4cb931253b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_asset_v1.py @@ -27,13 +27,15 @@ import tensorflow.compat.v1 as tf from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 # CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> () -# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset:.*]]"} +# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"} +# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"} # CHECK: func [[init]] -# CHECK-SAME: [[ARG:%.*]]: tensor {tf_saved_model.bound_input = @[[asset]]} +# CHECK-SAME: [[ARG0:%.*]]: tensor {tf_saved_model.bound_input = @[[asset0]]} +# CHECK-SAME: [[ARG1:%.*]]: tensor {tf_saved_model.bound_input = @[[asset1]]} # CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"() # CHECK-SAME: shared_name = "[[hash_table:.*]]" -# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG]]) +# CHECK-NEXT: "tf.InitializeTableFromTextFileV2"([[R0]], [[ARG0]]) def write_vocabulary_file(vocabulary): @@ -48,11 +50,16 @@ def write_vocabulary_file(vocabulary): def test(): + vocabulary_file = write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']) table_initializer = tf.lookup.TextFileInitializer( - write_vocabulary_file(['cat', 'is', 'on', 'the', 'mat']), tf.string, - tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, + vocabulary_file, tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER) + # Incur another bound_input on the asset, but with a different sym_name, i.e., + # __tf_saved_model_asset1_tokens.txt vs. __tf_saved_model_asset0_tokens.txt. table = tf.lookup.StaticVocabularyTable(table_initializer, num_oov_buckets=10) + vocab_file_tensor = tf.convert_to_tensor(vocabulary_file, tf.string, + name='asset_filepath') + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, vocab_file_tensor) x = tf.placeholder(tf.string, shape=(), name='input') r = table.lookup(x) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir new file mode 100644 index 00000000000..22fd3d86068 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_deduplicate_bound_input_bindings.mlir @@ -0,0 +1,33 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-saved-model-dedup-bound-input-binding-pass | FileCheck %s + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + // Test case: Remove duplicate bound_input symbols. 
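+  // Arguments bound to the same global tensor should collapse onto the first +  // such argument, and uses of the removed arguments are remapped to it.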
+ "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "w", type = tensor, value = dense<43.0> : tensor } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "x", type = tensor, value = dense<44.0> : tensor } : () -> () + // CHECK: func @f + // CHECK: %arg0: tensor>> {tf_saved_model.bound_input = @v} + // CHECK: %arg1: tensor>> {tf_saved_model.bound_input = @w} + // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @x} + // CHECK-NOT: %arg3 + // CHECK-NOT: %arg4 + func @f( + %arg0: tensor>> {tf_saved_model.bound_input = @v}, + %arg1: tensor>> {tf_saved_model.bound_input = @w}, + %arg2: tensor>> {tf_saved_model.bound_input = @v}, + %arg3: tensor>> {tf_saved_model.bound_input = @x}, + %arg4: tensor>> {tf_saved_model.bound_input = @v} + ) attributes {tf_saved_model.exported_names = ["f"]} { + // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg2) : (tensor>>) -> tensor + // CHECK: "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + %val0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + %val1 = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + %val2 = "tf.ReadVariableOp"(%arg2) : (tensor>>) -> tensor + %val3 = "tf.ReadVariableOp"(%arg3) : (tensor>>) -> tensor + %val4 = "tf.ReadVariableOp"(%arg4) : (tensor>>) -> tensor + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index 7156a1fab63..d2c5509b52d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -76,3 +76,16 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} } } + +// ----- + +module attributes {tf_saved_model.semantics, tf_saved_model.under_construction} { + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () + // CHECK: func @f + func @f( + %arg0: tensor>> {tf_saved_model.bound_input = @v}, + %arg1: tensor>> {tf_saved_model.bound_input = @v} + ) attributes {tf_saved_model.exported_names = ["f"]} { + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index dcb889ff99e..714c8908825 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -400,3 +400,17 @@ module attributes {tf_saved_model.semantics} { } } + +// ----- + +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.0> : tensor } : () -> () + // expected-error@+1 {{duplicate 'tf_saved_model.bound_input' binding}} + func @f( + %arg0: tensor>> {tf_saved_model.bound_input = @v}, + %arg1: tensor>> {tf_saved_model.bound_input = @v} + ) attributes {tf_saved_model.exported_names = ["f"]} { + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index 43be8743e51..1e308b42bfc 100644 --- 
a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -20,8 +20,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} + cond = @while_cond_7550, device = "", is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, @@ -217,8 +216,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} + cond = @while_cond_7550, device = "", is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, @@ -305,8 +303,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr {T = ["tfdtype$DT_INT32", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE", "tfdtype$DT_RESOURCE"], body = @while_body_7560, - cond = @while_cond_7550, device = "", is_stateless = false, - output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>]} + cond = @while_cond_7550, device = "", is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 208146a1226..1f516a25824 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-outside-compilation | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-extract-outside-compilation | FILECHECK_OPTS="" FileCheck %s // Tests that missing `_xla_outside_compilation` attribute value results in an error. 
@@ -143,14 +143,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_args" // CHECK: "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -172,15 +172,17 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1_args" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"() // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"() - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"() + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -203,14 +205,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" // CHECK: tf_device.return %[[HOST_OUTPUT]] %1:2 = 
tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -233,15 +236,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" // CHECK: "tf.C"(%[[HOST_OUTPUT]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -264,16 +267,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" - // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0) // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { @@ -299,24 +302,24 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]]) // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]]) // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK-SAME: key = "host_compute_channel_cluster2_retvals" // CHECK: "tf_device.launch" 
- // CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]]) // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]]) // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]]) - // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster2_retvals" // CHECK: "tf.E"(%[[HOST_OUTPUT2]]) %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { @@ -341,14 +344,14 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_args" // CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -370,22 +373,22 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]]) - // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK-SAME: key = "host_compute_channel_cluster2_args" // CHECK: "tf.D"(%[[RECV_OUTPUT_2]]) // CHECK: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = 
"host_compute_channel_cluster1_args" // CHECK: "tf.B"(%[[RECV_OUTPUT_1]]) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" - // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]]) + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" // CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C" - // CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster2" + // CHECK: "tf._XlaHostComputeMlir"(%[[C_OUTPUT]]) + // CHECK-SAME: send_key = "host_compute_channel_cluster2_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -408,16 +411,16 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" // CHECK-NEXT: "tf_device.launch" - // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir" + // CHECK: %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK-SAME: key = "host_compute_channel_cluster1_args" // CHECK: "tf.C"(%[[RECV_OUTPUT]]#0) // CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0) // CHECK: "tf_device.cluster" // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" - // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) - // CHECK-SAME: key = "host_compute_channel_cluster1" + // CHECK: "tf._XlaHostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]]) + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { %2 = "tf_device.cluster"() ( { %3 = "tf.A"() : () -> (tensor) @@ -453,4 +456,236 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor } return %1 : tensor } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op. 
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if + func @outside_compiled_ops_inside_tf_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion + // op with return values. 
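+  // The value computed by the outside compiled op on the host is sent back +  // over the retvals channel and yielded by the device-side IfRegion.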
+ + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_with_return_values + func @outside_compiled_ops_inside_tf_if_with_return_values( + %arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[ARG_RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[ARG_RECV_OUTPUT]]#0, %[[ARG_RECV_OUTPUT]]#1) + // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[HOST_COMPUTE_OUT:[0-9]*]] = "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK-NEXT: "tf.Yield"(%[[HOST_COMPUTE_OUT]]) + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %7 = "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor, tensor) -> (tensor) + "tf.Yield"(%7) : (tensor) -> () + }, { + + %8 = "tf.F"() : () -> (tensor) + "tf.Yield"(%8) : (tensor) -> () + }) { is_stateless = false} : (tensor) -> (tensor) + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion op without external inputs/outputs + + // CHECK-LABEL: func @outside_compiled_ops_inside_tf_if_without_input_outputs + func @outside_compiled_ops_inside_tf_if_without_input_outputs( + %arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK: "tf.D" + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: 
%[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%6) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of a single outside compiled cluster inside a nested + // tf.IfRegion op. + + // CHECK-LABEL: func @outside_compiled_ops_inside_nested_if + func @outside_compiled_ops_inside_nested_if(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[PREDICATE_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE_RECV_OUTPUT]]) + // CHECK-NEXT: %[[PREDICATE2_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_1" + // CHECK-NEXT: tf.IfRegion"(%[[PREDICATE2_RECV_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: %[[ARG_RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK: "tf.D"(%[[ARG_RECV_OUTPUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK-NEXT: "tf.Yield"() : () -> () + + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) {key = "if_predicate_channel_cluster1_0"} + // CHECK-NEXT: tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[H_OUTPUT:[0-9]*]] = "tf.H"(%[[B_OUTPUT]]) + // CHECK: "tf.XlaSendToHost"(%[[H_OUTPUT]]) {key = "if_predicate_channel_cluster1_1"} + // CHECK-NEXT: tf.IfRegion"(%[[H_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + // CHECK: %[[I_OUTPUT:[0-9]*]] = "tf.I"(%[[H_OUTPUT]]) + // CHECK: "tf._XlaHostComputeMlir"(%[[I_OUTPUT]]) + // CHECK-NEXT: "tf.Yield"() : () -> () + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %7 = "tf.H"(%4) : (tensor) -> (tensor) + + "tf.IfRegion"(%7)({ + "tf.Yield"() : () -> () + }, + { + %8 = "tf.I"(%7) : (tensor) -> (tensor) + "tf.D"(%8) {_xla_outside_compilation = "cluster1"} : (tensor) -> () + "tf.Yield"() : () -> () + }) { is_stateless = false} : (tensor) -> () + + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) { 
is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index fa70ca85419..2a0091ce9bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-rewrite -tpu_compile_metadata_debug | FileCheck %s +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-rewrite -tpu_compile_metadata_debug | FILECHECK_OPTS="" FileCheck %s // Tests module with missing `tf.versions` attribute. @@ -1256,21 +1256,21 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK: "tf._TPUCompileMlir" // CHECK: "tf.TPUCompileSucceededAssert" // CHECK: "tf_device.parallel_execute" - // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf.D"(%[[COMPILE_OUTPUT]]#1 // CHECK: "tf.TPUExecute" - // CHECK-NOT:"tf._TPUCompileMlir" + // CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1 %3 = "tf_device.parallel_execute"() ( { - %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor) - "tf.D"(%program) : (tensor) -> () + %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor + "tf.D"(%program) : (tensor) -> () tf_device.return }, { %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor tf_device.return %4 : tensor }, { - %status, %program = "tf._TPUCompileMlir"() {metadata = "...", mlir_module = "..."} : () -> (tensor, tensor) - "tf.E"(%program) : (tensor) -> () + %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor + "tf.E"(%program) : (tensor) -> () tf_device.return }) : () -> (tensor) tf_device.return %3 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir index 199426b1aa9..280986a7ee1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_space_to_depth_pass.mlir @@ -7,7 +7,7 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0" %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %1 = "tf.Const"() {value = dense<2> : tensor} : () -> tensor %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>, #tf.shape<>], parallel_iterations = 10 : i64} 
: (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) + %3:10 = "tf.While"(%2, %1, %2, %0, %1, %arg2, %arg4, %arg5, %arg6, %arg7) {_lower_using_switch_merge = true, _num_original_outputs = 10 : i64, _read_only_resource_inputs = [], body = @while_body_2710, cond = @while_cond_2700, device = "", is_stateless = false, parallel_iterations = 10 : i64} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor>>, tensor>>, tensor>>, tensor>>) return } // CHECK-LABEL: func @while_body_2710 diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir index b77e4b1fbd0..47374b7f7d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir @@ -9,16 +9,15 @@ // CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor // CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor // CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor -// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor func @check_enqueue_ops_update_for_eval(%arg0: tensor, %arg1: tensor, %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, - %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + %arg6: tensor, %arg7: tensor) -> () { // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> - %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor - // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]]) - "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + // CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"inference"> : tensor} : () -> tensor + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) return } @@ -34,20 +33,19 @@ func @check_enqueue_ops_update_for_eval(%arg0: tensor, %arg1: tensor // CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor // CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor -// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor func @check_enqueue_ops_update_for_training(%arg0: tensor, %arg1: tensor, %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, - %arg6: tensor, 
%arg7: tensor, %arg8: tensor) -> () { + %arg6: tensor, %arg7: tensor) -> () { // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> - %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor %2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32> %3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> "tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> () - // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]]) - "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + // CHECK: %[[CONST_MODE:[a-z0-9]*]] = "tf.Const"() {_xla_outside_compilation = "0", value = dense<"train"> : tensor} : () -> tensor + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[CONST_MODE]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %arg7) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () %4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) return } @@ -65,15 +63,3 @@ func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor, % return } -// ----- - -func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor, %arg1: tensor, - %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, - %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { - %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> - %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor - "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () - // expected-error @+1 {{'tf.RecvTPUEmbeddingActivations' op requires attribute '_tpu_embedding_layer'}} - %2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) - return -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir index 5a3f0b6e997..7cf5f19523d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/unroll-batch-matmul.mlir @@ -67,41 +67,35 @@ func 
@batchMatMulV2FlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) return %0 : tensor<3x4x6xf32> // CHECK-LABEL: batchMatMulV2FlatInput - // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>} // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} - // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} - // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} - // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} - // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} - // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} - // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} - // CHECK: %[[cst_8:.*]] = "tf.Const"() {value = dense<[3, 4, 6]> : tensor<3xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} - // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> - // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v0:.*]] = "tf.Slice"(%arg0, %[[cst_2]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%[[v0]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Slice"(%arg0, %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Slice"(%arg0, %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Reshape"(%[[v4]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> - // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: 
%[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v6:.*]] = "tf.Slice"(%arg1, %[[cst_2]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v7:.*]] = "tf.Reshape"(%[[v6]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%arg1, %[[cst_3]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%arg1, %[[cst_4]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm0:.*]] = "tf.MatMul"(%[[v1]], %[[v7]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm1:.*]] = "tf.MatMul"(%[[v3]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm2:.*]] = "tf.MatMul"(%[[v5]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> - // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[mm0]], %[[mm1]], %[[mm2]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> - // CHECK: return %[[v18]] : tensor<3x4x6xf32> + // CHECK: return %[[v17]] : tensor<3x4x6xf32> } // ----- @@ -184,41 +178,35 @@ func @batchMatMulFlatInput(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) - return %0 : tensor<3x4x6xf32> // CHECK-LABEL: batchMatMulFlatInput - // CHECK: %[[cst:.*]] = "tf.Const"() {value = dense<[3, 4, 5]> : tensor<3xi64>} // CHECK: %[[cst_0:.*]] = "tf.Const"() {value = dense<[1, 4, 5]> : tensor<3xi64>} // CHECK: %[[cst_1:.*]] = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi64>} - // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<[3, 5, 6]> : tensor<3xi64>} - // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} - // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} - // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} - // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} - // CHECK: %[[cst_7:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} - // CHECK: %[[cst_8:.*]] = "tf.Const"() 
{value = dense<[3, 4, 6]> : tensor<3xi64>} + // CHECK: %[[cst_2:.*]] = "tf.Const"() {value = dense<0> : tensor<3xi64>} + // CHECK: %[[cst_3:.*]] = "tf.Const"() {value = dense<[1, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_4:.*]] = "tf.Const"() {value = dense<[2, 0, 0]> : tensor<3xi64>} + // CHECK: %[[cst_5:.*]] = "tf.Const"() {value = dense<[1, 5, 6]> : tensor<3xi64>} + // CHECK: %[[cst_6:.*]] = "tf.Const"() {value = dense<[5, 6]> : tensor<2xi64>} - // CHECK: %[[v0:.*]] = "tf.Reshape"(%arg0, %[[cst]]) : (tensor<3x4x5xf32>, tensor<3xi64>) -> tensor<3x4x5xf32> - // CHECK: %[[v1:.*]] = "tf.Slice"(%[[v0]], %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v2:.*]] = "tf.Reshape"(%[[v1]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v3:.*]] = "tf.Slice"(%[[v0]], %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v4:.*]] = "tf.Reshape"(%[[v3]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v5:.*]] = "tf.Slice"(%[[v0]], %[[cst_5]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> - // CHECK: %[[v6:.*]] = "tf.Reshape"(%[[v5]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v0:.*]] = "tf.Slice"(%arg0, %[[cst_2]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v1:.*]] = "tf.Reshape"(%[[v0]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v2:.*]] = "tf.Slice"(%arg0, %[[cst_3]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v3:.*]] = "tf.Reshape"(%[[v2]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> + // CHECK: %[[v4:.*]] = "tf.Slice"(%arg0, %[[cst_4]], %[[cst_0]]) : (tensor<3x4x5xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x4x5xf32> + // CHECK: %[[v5:.*]] = "tf.Reshape"(%[[v4]], %[[cst_1]]) : (tensor<1x4x5xf32>, tensor<2xi64>) -> tensor<4x5xf32> - // CHECK: %[[v7:.*]] = "tf.Reshape"(%arg1, %[[cst_2]]) : (tensor<3x5x6xf32>, tensor<3xi64>) -> tensor<3x5x6xf32> - // CHECK: %[[v8:.*]] = "tf.Slice"(%[[v7]], %[[cst_3]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v10:.*]] = "tf.Slice"(%[[v7]], %[[cst_4]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v12:.*]] = "tf.Slice"(%[[v7]], %[[cst_5]], %[[cst_6]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> - // CHECK: %[[v13:.*]] = "tf.Reshape"(%[[v12]], %[[cst_7]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v6:.*]] = "tf.Slice"(%arg1, %[[cst_2]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v7:.*]] = "tf.Reshape"(%[[v6]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v8:.*]] = "tf.Slice"(%arg1, %[[cst_3]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v9:.*]] = "tf.Reshape"(%[[v8]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> + // CHECK: %[[v10:.*]] = "tf.Slice"(%arg1, 
%[[cst_4]], %[[cst_5]]) : (tensor<3x5x6xf32>, tensor<3xi64>, tensor<3xi64>) -> tensor<1x5x6xf32> + // CHECK: %[[v11:.*]] = "tf.Reshape"(%[[v10]], %[[cst_6]]) : (tensor<1x5x6xf32>, tensor<2xi64>) -> tensor<5x6xf32> - // CHECK: %[[v14:.*]] = "tf.MatMul"(%[[v2]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v15:.*]] = "tf.MatMul"(%[[v4]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v16:.*]] = "tf.MatMul"(%[[v6]], %[[v13]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm0:.*]] = "tf.MatMul"(%[[v1]], %[[v7]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm1:.*]] = "tf.MatMul"(%[[v3]], %[[v9]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> + // CHECK: %[[mm2:.*]] = "tf.MatMul"(%[[v5]], %[[v11]]) {transpose_a = false, transpose_b = false} : (tensor<4x5xf32>, tensor<5x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[v17:.*]] = "tf.Pack"(%[[v14]], %[[v15]], %[[v16]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> - // CHECK: %[[v18:.*]] = "tf.Reshape"(%[[v17]], %[[cst_8]]) : (tensor<3x4x6xf32>, tensor<3xi64>) -> tensor<3x4x6xf32> + // CHECK: %[[v17:.*]] = "tf.Pack"(%[[mm0]], %[[mm1]], %[[mm2]]) {axis = 0 : i64} : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) -> tensor<3x4x6xf32> - // CHECK: return %[[v18]] : tensor<3x4x6xf32> + // CHECK: return %[[v17]] : tensor<3x4x6xf32> } // ----- diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h new file mode 100644 index 00000000000..599a8df63d7 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Copies attributes that satisfy the given predicate from `from` to `to`. +template +void CopyAttributes(Operation *from, Operation *to, Predicate P) { + for (const NamedAttribute &attr : from->getAttrs()) + if (P(attr)) to->setAttr(attr.first, attr.second); +} + +// Copies attributes whose name begins with an _ from `from` to `to`. 
+inline void CopyUnderscoredAttributes(Operation *from, Operation *to) { + CopyAttributes(from, to, [](const NamedAttribute &attr) { + return attr.first.strref().front() == '_'; + }); +} + +} // namespace TF +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 1963931b497..2a5c8a05ef3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -40,14 +40,17 @@ void EnableLogging(PassManager *pm) { namespace TFTPU { namespace { void AddGraphExportLoweringPasses(OpPassManager &pm) { + auto add_pass = [&](std::unique_ptr pass) { + pm.addNestedPass(std::move(pass)); + pm.addPass(CreateBreakUpIslandsPass()); + }; + pm.addNestedPass(CreateFunctionalToExecutorDialectConversionPass()); - pm.addNestedPass(CreateBreakUpIslandsPass()); - pm.addNestedPass(TFDevice::CreateReplicateToIslandPass()); - pm.addNestedPass(CreateBreakUpIslandsPass()); - pm.addNestedPass(TFDevice::CreateParallelExecuteToIslandsPass()); - pm.addNestedPass(CreateBreakUpIslandsPass()); - pm.addNestedPass(TFDevice::CreateLaunchToDeviceAttributePass()); - pm.addNestedPass(CreateBreakUpIslandsPass()); + add_pass(TFDevice::CreateParallelizeEmbeddingParamsOpsPass()); + pm.addPass(TFDevice::CreateReplicateToIslandPass()); + pm.addPass(CreateBreakUpIslandsPass()); + add_pass(TFDevice::CreateParallelExecuteToIslandsPass()); + add_pass(TFDevice::CreateLaunchToDeviceAttributePass()); } tensorflow::Status RunTPUBridge( @@ -80,14 +83,23 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { // Run shape inference so that tf_executor/tf_device ops created later will // likely to inherit more concrete types. pm.addPass(TF::CreateTFShapeInferencePass()); - OpPassManager &func_pm = pm.nest(); - func_pm.addPass(CreateTPUClusterFormationPass()); - // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass - // because DecomposeResourceOpsPass uses pattern rewriter which hoists - // changed constants out of tf_device.Launch. - func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); - func_pm.addPass(CreateTPUHostComputationExpansionPass()); + // Encode this in its own scope so that func_pm is not mistakenly used + // later on. + { + OpPassManager &func_pm = pm.nest(); + func_pm.addPass(CreateTPUClusterFormationPass()); + // Place DecomposeResourceOpsPass before TFExecutorConstantSinking pass + // because DecomposeResourceOpsPass uses pattern rewriter which hoists + // changed constants out of tf_device.Launch. + func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); + func_pm.addPass(CreateTPUHostComputationExpansionPass()); + func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); + } + pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); + // Run another shape inference pass because resource decomposition might have // created new partial types. 
pm.addPass(TF::CreateTFShapeInferencePass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index 9d72284da91..d5b7eb7a739 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -150,6 +150,7 @@ def LogToLog1p : Pat< // LogicalNot op patterns. //===----------------------------------------------------------------------===// +// TODO(ezhulenev): Generalize this pattern for all involutions. def LogicalNotNested : Pat<(TF_LogicalNotOp (TF_LogicalNotOp $arg)), (replaceWithValue $arg)>; @@ -187,6 +188,13 @@ def NegNested : Pat<(TF_NegOp (TF_NegOp $arg)), (replaceWithValue $arg)>; def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), (TF_MulOp $arg0, (TF_RsqrtOp $arg1))>; +// Replace division by a constant with a multiplication by a reciprocal of that +// constant. Floating point division can be ~10x more expensive than a +// multiplication. +def RealDivWithConstDivisor : Pat< + (TF_RealDivOp $arg0, (TF_ConstOp FloatElementsAttr<32>:$value)), + (TF_MulOp $arg0, (TF_ReciprocalOp (TF_ConstOp $value)))>; + //===----------------------------------------------------------------------===// // Reciprocal op patterns. //===----------------------------------------------------------------------===// @@ -201,6 +209,11 @@ def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape), (TF_ReshapeOp $arg, $shape)>; +def IsSame : Constraint>; +def ReshapeToSelfShape : Pat<(TF_ReshapeOp $arg0, (TF_ShapeOp $arg1)), + (replaceWithValue $arg0), + [(IsSame $arg0, $arg1)]>; + //===----------------------------------------------------------------------===// // Select op patterns. 
//===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 58c4eac5c95..57a5cd888a1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -77,8 +77,7 @@ Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, ArrayRef{RankedTensorType::get( {static_cast(buffer_type.getShape().size())}, getElementTypeOrSelf(index.getType()))}, - ArrayRef{index, zeros_tensor, CreateScalarConst(0, builder, loc)}, - ArrayRef{}); + ArrayRef{index, zeros_tensor, CreateScalarConst(0, builder, loc)}); } Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, @@ -95,15 +94,14 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, auto slice = builder.create( loc, ArrayRef{slice_type}, ArrayRef{buffer, GetIndicesForElement(index, buffer, builder, loc), - size_const}, - ArrayRef{}); + size_const}); if (keep_slice_shape) return slice; auto element_type = RankedTensorType::get(buffer_type.getShape().drop_front(), buffer_type.getElementType()); auto reshape = builder.create( loc, ArrayRef{element_type}, - ArrayRef{slice, GetR1Const(element_type.getShape(), builder, loc)}, - ArrayRef{}); + ArrayRef{slice, + GetR1Const(element_type.getShape(), builder, loc)}); return reshape.output(); } @@ -120,15 +118,13 @@ Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, if (element.getType() != slice_type) { update_slice = builder.create( loc, ArrayRef{slice_type}, - ArrayRef{element, GetR1Const(slice_shape, builder, loc)}, - ArrayRef{}); + ArrayRef{element, GetR1Const(slice_shape, builder, loc)}); } return builder .create( loc, ArrayRef{buffer.getType()}, ArrayRef{buffer, update_slice, - GetIndicesForElement(index, buffer, builder, loc)}, - ArrayRef{}) + GetIndicesForElement(index, buffer, builder, loc)}) .output(); } @@ -140,8 +136,7 @@ Value ReshapeScalarToSizeType(OpBuilder builder, Value scalar, Location loc) { auto size_type = GetSizeType(builder); return builder.create( loc, ArrayRef{size_type}, - ArrayRef{scalar, GetR1Const(size_type.getShape(), builder, loc)}, - ArrayRef{}); + ArrayRef{scalar, GetR1Const(size_type.getShape(), builder, loc)}); } LogicalResult CreateInitBufferValue(ArrayRef element_shape, @@ -171,13 +166,12 @@ LogicalResult CreateInitBufferValue(ArrayRef element_shape, if (getElementTypeOrSelf(zero.getType()) != element_dtype) { zero = builder.create( op->getLoc(), ArrayRef{RankedTensorType::get({}, element_dtype)}, - ArrayRef{zero}, ArrayRef{}); + ArrayRef{zero}); } auto buffer_type = RankedTensorType::get(buffer_shape, element_dtype); auto broadcast = builder.create( op->getLoc(), ArrayRef{buffer_type}, - ArrayRef{zero, GetR1Const(buffer_shape, builder, op->getLoc())}, - ArrayRef{}); + ArrayRef{zero, GetR1Const(buffer_shape, builder, op->getLoc())}); *buffer = broadcast.output(); return success(); } @@ -187,14 +181,14 @@ llvm::Optional GetElementTypeFromAccess( llvm::function_ref(Operation*)> infer_from_op) { for (auto& use : collection.getUses()) { if (auto while_op = llvm::dyn_cast(use.getOwner())) { - auto body = module.lookupSymbol(while_op.body()); + auto body = while_op.body_func(); assert(body); auto type_from_body = GetElementTypeFromAccess( body.getArgument(use.getOperandNumber()), module, infer_from_op); if 
(type_from_body.hasValue()) return type_from_body; } else if (auto if_op = llvm::dyn_cast(use.getOwner())) { - auto then_branch = module.lookupSymbol(if_op.then_branch()); - auto else_branch = module.lookupSymbol(if_op.else_branch()); + auto then_branch = if_op.then_func(); + auto else_branch = if_op.else_func(); assert(then_branch && else_branch); auto type_from_then = GetElementTypeFromAccess( then_branch.getArgument(use.getOperandNumber() - 1), module, @@ -204,18 +198,8 @@ llvm::Optional GetElementTypeFromAccess( else_branch.getArgument(use.getOperandNumber() - 1), module, infer_from_op); if (type_from_else.hasValue()) return type_from_else; - } else if (auto pcall = - llvm::dyn_cast(use.getOwner())) { - if (!pcall.f().isa()) continue; - auto callee = module.lookupSymbol(pcall.f().getRootReference()); - assert(callee); - auto type_from_callee = GetElementTypeFromAccess( - callee.getArgument(use.getOperandNumber()), module, infer_from_op); - if (type_from_callee.hasValue()) return type_from_callee; - } else if (auto spcall = llvm::dyn_cast( - use.getOwner())) { - auto callee = module.lookupSymbol(spcall.f()); - assert(callee); + } else if (auto call = llvm::dyn_cast(use.getOwner())) { + auto callee = dyn_cast(call.resolveCallable()); auto type_from_callee = GetElementTypeFromAccess( callee.getArgument(use.getOperandNumber()), module, infer_from_op); if (type_from_callee.hasValue()) return type_from_callee; @@ -241,27 +225,24 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) { ArrayRef{getElementTypeOrSelf(local_var.getType()) .cast() .getSubtypes()[0]}, - ArrayRef{local_var}, ArrayRef{}) + ArrayRef{local_var}) .value(); } // Creates an AssignVariableOp on a local variable. TF::AssignVariableOp WriteLocalVariable(Value local_var, Value value, OpBuilder builder, Location loc) { - return builder.create(loc, ArrayRef{}, - ArrayRef{local_var, value}, - ArrayRef{}); + return builder.create( + loc, ArrayRef{}, ArrayRef{local_var, value}); } Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) { if (getElementTypeOrSelf(a.getType()) == builder.getI1Type()) { return builder.create(loc, ArrayRef{a.getType()}, - ArrayRef{a, b}, - ArrayRef{}); + ArrayRef{a, b}); } return builder.create(loc, ArrayRef{a.getType()}, - ArrayRef{a, b}, - ArrayRef{}); + ArrayRef{a, b}); } namespace { @@ -303,15 +284,13 @@ Value GatherElements(Value indices, Value buffer, OpBuilder builder, return builder.create( loc, ArrayRef{slice_type}, ArrayRef{buffer, GetR1Const(slice_starts, builder, loc), - GetR1Const(result_shape, builder, loc)}, - ArrayRef{}); + GetR1Const(result_shape, builder, loc)}); } auto result_type = RankedTensorType::get(result_shape, buffer_type.getElementType()); return builder.create( loc, ArrayRef{result_type}, - ArrayRef{buffer, indices, CreateScalarConst(0, builder, loc)}, - ArrayRef{}); + ArrayRef{buffer, indices, CreateScalarConst(0, builder, loc)}); } Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, @@ -334,8 +313,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, auto index = builder.create( loc, ArrayRef{GetSizeType(builder)}, ArrayRef{indices, GetR1Const({i}, builder, loc), - GetR1Const({1}, builder, loc)}, - ArrayRef{}); + GetR1Const({1}, builder, loc)}); auto old_slice = GetElement(index, buffer, builder, loc, /*keep_slice_shape=*/true); starts_in_update[0] = i; @@ -344,8 +322,7 @@ Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, builder .create( loc, 
ArrayRef{old_slice.getType()}, - ArrayRef{updates, update_slice_starts, slice_sizes}, - ArrayRef{}) + ArrayRef{updates, update_slice_starts, slice_sizes}) .output(); slice = AccumulateBuffers(old_slice, slice, builder, loc); buffer = SetElement(index, buffer, slice, builder, loc); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 007baaae433..1429e2b3fd4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -40,7 +40,7 @@ namespace TF { // LINT.IfChange(folding-policy) static bool ShouldBeFolded(Operation* inst) { constexpr int kSizeFactor = 2; - constexpr int64_t kSizeThreshold = (1 << 20); // 128 KB + constexpr int64_t kSizeThreshold = (1 << 21); // 256 KB bool has_unknown_shape = false; auto get_size = [&](TypeRange types) { int64_t size = 0; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc b/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc new file mode 100644 index 00000000000..c1514dfa357 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc @@ -0,0 +1,65 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace tf_saved_model { +namespace { + +class DedupBoundInputBindingPass + : public PassWrapper { + public: + void runOnFunction() override; +}; + +void DedupBoundInputBindingPass::runOnFunction() { + FuncOp func = getFunction(); + if (!mlir::tf_saved_model::IsExported(func)) return; + llvm::SmallDenseMap unique_bound_inputs; + llvm::SmallVector arg_indices_to_erase; + for (unsigned i = 0, e = func.getNumArguments(); i < e; i++) { + auto attr = func.getArgAttrOfType( + i, "tf_saved_model.bound_input"); + if (!attr) continue; + auto inserted = unique_bound_inputs.insert(std::make_pair(attr, i)); + if (inserted.second) continue; + auto duplicate_arg = func.getArgument(i); + auto original_arg = func.getArgument(unique_bound_inputs[attr]); + duplicate_arg.replaceAllUsesWith(original_arg); + arg_indices_to_erase.push_back(i); + } + func.eraseArguments(arg_indices_to_erase); +} + +} // namespace + +static PassRegistration pass( + "tf-saved-model-dedup-bound-input-binding-pass", + "Remove duplicate 'tf_saved_model.bound_input' bindings."); + +std::unique_ptr> CreateDedupBoundInputBindingPass() { + return std::make_unique(); +} + +} // namespace tf_saved_model +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 1e622a295ec..69dab58c3f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -54,6 +54,8 @@ enum EinsumEquation { TransposeMatMul, BatchMatMulReducedDim, TransposeReducedDim, + FourDReduceLast, + FourDTransposeAll, UnsupportedEquation }; @@ -72,7 +74,7 @@ constexpr int kNumSupportedEquationVariables = 5; // A - E for now. 
bool tokenizeEquation(const llvm::StringRef& equation, std::vector* tokens) { std::map label_axis_mapping; - int index = 0; + size_t index = 0; int variable_count = 0; llvm::Regex r("[[:alpha:]]"); while (index < equation.size()) { @@ -146,6 +148,14 @@ EinsumEquation parseEquation(const std::vector& eqn) { if (is_equal(eqn, {A, B, C, COMMA, A, B, D, C, ARROW, A, B, D})) { return EinsumEquation::TransposeReducedDim; } + // ABCD,ADBE->ACBE + if (is_equal(eqn, {A, B, C, D, COMMA, A, D, B, E, ARROW, A, C, B, E})) { + return EinsumEquation::FourDReduceLast; + } + // ABCD,AECD->ACEB + if (is_equal(eqn, {A, B, C, D, COMMA, A, E, C, D, ARROW, A, C, E, B})) { + return EinsumEquation::FourDTransposeAll; + } return EinsumEquation::UnsupportedEquation; } @@ -167,7 +177,7 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, auto perm_attr = DenseElementsAttr::get(perm_type, permutation); auto perm_op = rewriter->create(loc, perm_type, perm_attr); std::vector transposed_shape(shape.begin(), shape.end()); - for (int i = 0; i < shape.size(); ++i) { + for (int i = 0, end = shape.size(); i < end; ++i) { transposed_shape[i] = shape[permutation[i]]; } auto transposed_type = @@ -187,7 +197,7 @@ TF::SumOp createSumOp(Value value, Location loc, auto redux_op = rewriter->create(loc, redux_type, redux_attr); std::vector sum_shape(shape.size() - redux_axes.size()); int count = 0; - for (int i = 0; i < shape.size(); ++i) { + for (int i = 0, end = shape.size(); i < end; ++i) { if (std::find(redux_axes.begin(), redux_axes.end(), i) == redux_axes.end()) { sum_shape[count] = shape[i]; @@ -380,6 +390,7 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim3}, bmm_element_type, loc, &rewriter); rewriter.replaceOp(op, {final_reshape.getResult()}); + return success(); } if (einsum_eqn == EinsumEquation::TransposeReducedDim) { // Case "BIJ,BINJ->BIN" @@ -404,6 +415,45 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( auto final_reshape = createReshapeOp(bmm_op, {lhs_dim0, lhs_dim1, rhs_dim2}, bmm_element_type, loc, &rewriter); rewriter.replaceOp(op, {final_reshape.getResult()}); + return success(); + } + if (einsum_eqn == EinsumEquation::FourDReduceLast) { + // Case "acbe,aecd->abcd" + const int lhs_dim2 = lhs_shape[2]; + const int rhs_dim0 = rhs_shape[0]; + const int rhs_dim2 = rhs_shape[2]; + const int rhs_dim3 = rhs_shape[3]; + // Transpose RHS + rhs = createTransposeOp(rhs, loc, {0, 2, 1, 3}, &rewriter); + std::vector bmm_shape = {rhs_dim0, rhs_dim2, lhs_dim2, rhs_dim3}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, lhs, rhs, rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 2, 1, 3}, &rewriter); + rewriter.replaceOp(op, {trans_bmm.getResult()}); + return success(); + } + if (einsum_eqn == EinsumEquation::FourDTransposeAll) { + // Case "aecd,abcd->acbe" + const int lhs_dim0 = lhs_shape[0]; + const int lhs_dim1 = lhs_shape[1]; + const int lhs_dim2 = lhs_shape[2]; + const int rhs_dim1 = rhs_shape[1]; + // Transpose LHS + lhs = createTransposeOp(lhs, loc, {0, 2, 1, 3}, &rewriter); + // Transpose RHS + rhs = createTransposeOp(rhs, loc, {0, 2, 3, 1}, &rewriter); + std::vector bmm_shape = {lhs_dim0, lhs_dim2, lhs_dim1, rhs_dim1}; + auto bmm_type = RankedTensorType::get(bmm_shape, rhs_type.getElementType()); + auto bmm_op = rewriter.create( + loc, ArrayRef{bmm_type}, lhs, rhs, 
rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)); + + auto trans_bmm = createTransposeOp(bmm_op, loc, {0, 1, 3, 2}, &rewriter); + rewriter.replaceOp(op, {trans_bmm.getResult()}); + return success(); } return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 0d72a7638a3..02a2e7efa6f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -185,8 +185,8 @@ IslandOp CreateNewIsland(IslandOp parent, IslandOp child, Operation* old_island = insert_position == kParentIsland ? parent : child; OpBuilder builder(old_island); - auto new_island = builder.create( - old_island->getLoc(), result_types, operands, ArrayRef{}); + auto new_island = + builder.create(old_island->getLoc(), result_types, operands); new_island.body().push_back(new Block); return new_island; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 9a533798208..f624d6cad58 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -61,11 +61,11 @@ void TPUBridgeExecutorIslandInlining::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "Found call to inline: " << *call_op.getOperation() << "\n"); - FuncOp called_func = dyn_cast_or_null( - symbol_table.lookupSymbolIn(getOperation(), call_op.f())); + auto call_interface = cast(call_op.getOperation()); + auto called_func = + dyn_cast_or_null(call_interface.resolveCallable()); - if (failed(inlineCall(inliner, - cast(call_op.getOperation()), + if (failed(inlineCall(inliner, call_interface, cast(called_func.getOperation()), called_func.getCallableRegion(), /* shouldCloneInlinedRegion = */ false))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index e04f6bf3daa..a5177fac647 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -105,9 +105,10 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() { // Create the outlined function SmallString<32> name = kOutlinedFuncPrefix; name += llvm::Twine(prefix_id++).str(); - auto outlined_func = OpBuilder(ctx).create( - island_op.getLoc(), name, func_type, ArrayRef()); + auto outlined_func = + OpBuilder(ctx).create(island_op.getLoc(), name, func_type); outlined_symbol_table.insert(outlined_func); + outlined_func.setVisibility(FuncOp::Visibility::Nested); // We will "steal" the body of the island and replace it with a call to the // new function later. 
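A recurring change in the hunks above and below is how a call target is found: instead of looking the callee up by name in the module symbol table (the removed module.lookupSymbol calls on call_op.f(), then_branch(), cond(), and so on), the call op is cast to CallOpInterface and resolveCallable() is asked for the target. The sketch below only illustrates that pattern and is not code from this patch; the helper name GetCalleeBody and the exact header paths are assumptions.

// Hedged sketch of the CallOpInterface-based callee resolution adopted in this change.
#include "mlir/IR/Function.h"                // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project (path assumed)
#include "mlir/Support/LLVM.h"               // from @llvm-project

namespace mlir {

// Returns the body region of the function called by `op`, or nullptr when
// `op` is not a call or its callee cannot be resolved (e.g. an indirect call
// through a function-typed value).
static Region *GetCalleeBody(Operation *op) {
  auto call = dyn_cast<CallOpInterface>(op);
  if (!call) return nullptr;
  // resolveCallable() follows the symbol reference itself, so the caller no
  // longer needs a ModuleOp and a per-op symbol lookup.
  auto callee = dyn_cast_or_null<FuncOp>(call.resolveCallable());
  if (!callee) return nullptr;
  return callee.getCallableRegion();
}

}  // namespace mlir

Routing everything through the interface keeps the passes agnostic to the concrete call op, which is why the PartitionedCall/StatefulPartitionedCall special cases could be dropped from collection_ops_util.cc and why the island inlining pass can feed the resolved function's getCallableRegion() straight into inlineCall.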
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index a0be88cc564..d8678e620f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -140,10 +140,6 @@ static LogicalResult LowerIfOp(IfOp op) { Value cond_i1 = LowerCondition(loc, op.cond(), &builder); if (!cond_i1) return failure(); - auto module = op_inst->getParentOfType(); - auto then_fn = module.lookupSymbol(op.then_branch()); - auto else_fn = module.lookupSymbol(op.else_branch()); - // Split the basic block before the 'if'. The new dest will be our merge // point. Block* orig_block = op_inst->getBlock(); @@ -161,14 +157,14 @@ static LogicalResult LowerIfOp(IfOp op) { // Set up the 'then' block. Block* then_block = builder.createBlock(merge_block); - Operation* call_op = CallFn(loc, get_operand, then_fn, &builder); + Operation* call_op = CallFn(loc, get_operand, op.then_func(), &builder); auto get_then_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_then_result, merge_block, &builder); // Set up the 'else' block. Block* else_block = builder.createBlock(merge_block); - call_op = CallFn(loc, get_operand, else_fn, &builder); + call_op = CallFn(loc, get_operand, op.else_func(), &builder); auto get_else_result = [&](int i) { return call_op->getResult(i); }; JumpToBlock(loc, get_else_result, merge_block, &builder); @@ -194,9 +190,8 @@ static LogicalResult LowerWhileOp(WhileOp op) { OpBuilder builder(op_inst); - auto module = op_inst->getParentOfType(); - auto cond_fn = module.lookupSymbol(op.cond()); - auto body_fn = module.lookupSymbol(op.body()); + auto cond_fn = op.cond_func(); + auto body_fn = op.body_func(); // Split the block containing the While op into two blocks. One containing // operations before the While op and other containing the rest. Create two diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 5ab0eda08c6..d23b977f0e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -15,7 +15,7 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // TensorFlow dialect to their region based counterparts, i.e., -// tf.If -> tf.IfRegion +// tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -31,8 +31,11 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#define DEBUG_TYPE "tf-functional-cf-to-region" + namespace mlir { namespace TF { @@ -44,24 +47,36 @@ struct FunctionalControlFlowToRegions void runOnOperation() override; }; -// Create a call to function `fn` with arguments `args` and return the CallOp. -// The arguments are cast to the required type before the call. 
-CallOp CreateCall(Location loc, Operation::operand_range args, FuncOp fn, - OpBuilder* builder) { - FunctionType fn_type = fn.getType(); - llvm::SmallVector operands; - int num_operands = fn_type.getNumInputs(); - operands.reserve(num_operands); - for (const auto& ArgAndType : zip(args, fn_type.getInputs())) { +// Creates a call to function `func` in region `caller_region`. Use `args` as +// the call arguments, and terminate the region with a yield. The arguments are +// cast to the required type before the call. `use_region_args` control whether +// the input arguments are used as is (for IfOp) or block arguments of the same +// type as the input arguments are created and then used as call arguments (for +// While). +void CreateCall(Operation* op, FuncOp func, Region& caller_region, + ValueRange args, bool use_region_args) { + assert(caller_region.empty() && + "Expected empty region for newly created ops"); + OpBuilder builder(caller_region); + Block* entry = builder.createBlock(&caller_region); + + if (use_region_args) { + entry->addArguments(args.getType()); + args = entry->getArguments(); + } + llvm::SmallVector casted_args; + casted_args.reserve(func.getNumArguments()); + for (const auto& ArgAndType : zip(args, func.getType().getInputs())) { Value arg = std::get<0>(ArgAndType); Type expected_type = std::get<1>(ArgAndType); if (arg.getType() != expected_type) { - arg = builder->create(loc, expected_type, arg, - /*Truncate=*/builder->getBoolAttr(false)); + arg = builder.create(op->getLoc(), expected_type, arg, + /*Truncate=*/builder.getBoolAttr(false)); } - operands.push_back(arg); + casted_args.push_back(arg); } - return builder->create(loc, fn, operands); + auto call = builder.create(op->getLoc(), func, casted_args); + builder.create(op->getLoc(), call.getResults()); } // Transform a functional IfOp to a region based IfRegionOp. @@ -69,32 +84,47 @@ LogicalResult ConvertIfOp(IfOp if_op) { auto if_region = OpBuilder(if_op).create( if_op.getLoc(), if_op.getResultTypes(), if_op.cond(), if_op.is_stateless()); + CopyUnderscoredAttributes(if_op, if_region); - // Insert call to the given function into the 'region'. 
- auto create_region_with_call = [&if_op](FlatSymbolRefAttr symbol, - Region& region) { - OpBuilder builder(region); - builder.createBlock(®ion); - auto func = if_op.getParentOfType().lookupSymbol( - symbol.getValue()); - auto call = CreateCall(if_op.getLoc(), if_op.input(), func, &builder); - builder.create(if_op.getLoc(), call.getResults()); - }; - - create_region_with_call(if_op.then_branchAttr(), if_region.then_branch()); - create_region_with_call(if_op.else_branchAttr(), if_region.else_branch()); - + CreateCall(if_op, if_op.then_func(), + /*caller_region=*/if_region.then_branch(), if_op.input(), + /*use_region_args=*/false); + CreateCall(if_op, if_op.else_func(), + /*caller_region=*/if_region.else_branch(), if_op.input(), + /*use_region_args=*/false); if_op.replaceAllUsesWith(if_region.getResults()); if_op.erase(); return success(); } +LogicalResult ConvertWhileOp(WhileOp while_op) { + auto while_region = OpBuilder(while_op).create( + while_op.getLoc(), while_op.getResultTypes(), while_op.input(), + while_op.is_stateless(), while_op.parallel_iterations()); + CopyUnderscoredAttributes(while_op, while_region); + + CreateCall(while_op, while_op.cond_func(), + /*caller_region=*/while_region.cond(), while_op.input(), + /*use_region_args=*/true); + CreateCall(while_op, while_op.body_func(), + /*caller_region=*/while_region.body(), while_op.input(), + /*use_region_args=*/true); + while_op.replaceAllUsesWith(while_region.getResults()); + while_op.erase(); + return success(); +} + void FunctionalControlFlowToRegions::runOnOperation() { ModuleOp module = getOperation(); auto result = module.walk([](Operation* op) { if (IfOp if_op = llvm::dyn_cast(op)) { if (failed(ConvertIfOp(if_op))) { - if_op.emitOpError() << " failed to convert to region form"; + op->emitOpError() << "failed to convert to region form"; + return WalkResult::interrupt(); + } + } else if (auto while_op = llvm::dyn_cast(op)) { + if (failed(ConvertWhileOp(while_op))) { + op->emitOpError() << "failed to convert to region form"; return WalkResult::interrupt(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index e2090803c00..7563f606434 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -51,7 +51,7 @@ Status MlirGraphOptimizationPass::Run(const ConfigProto& config_proto, CreateLayoutOptimizationPipeline(pm, layout_optimization_options); // Prepare IR for exporting. - pm.addNestedPass(CreateBreakUpIslandsPass()); + pm.addPass(CreateBreakUpIslandsPass()); // In case of failure, the `diag_handler` converts MLIR errors emitted to the // MLIRContext into a tensorflow::Status. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index f4d3eda3e7e..859d3ffb23c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -19,15 +19,62 @@ limitations under the License. 
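A minimal standalone sketch (plain C++, not MLIR) of the walk-and-interrupt style used by runOnOperation above: a visitor is applied to every node and the first failure stops the traversal, mirroring WalkResult::interrupt() followed by signalPassFailure(). The Node type and op names are illustrative.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;            // e.g. "tf.If", "tf.While", "tf.Add"
  std::vector<Node> children;  // nested ops, loosely modeling regions
};

// Pre-order walk that returns false as soon as the visitor fails on any node.
bool Walk(const Node& node, const std::function<bool(const Node&)>& visit) {
  if (!visit(node)) return false;  // "interrupt"
  for (const Node& child : node.children)
    if (!Walk(child, visit)) return false;
  return true;  // "advance"
}

int main() {
  Node module{"module", {{"tf.If", {}}, {"tf.While", {}}, {"tf.Add", {}}}};
  bool ok = Walk(module, [](const Node& n) {
    if (n.name == "tf.If" || n.name == "tf.While")
      std::cout << "converting " << n.name << " to region form\n";
    return true;  // return false here to simulate a failed conversion
  });
  std::cout << (ok ? "walk completed" : "walk interrupted") << "\n";
}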
#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace mlir { namespace tf_executor { +// Visits an op's operand if it is an output of an Operation in the same +// tf_executor.graph. +void VisitOpOperand(GraphOp graph, Value operand, + llvm::SmallPtrSetImpl* reachable_ops, + llvm::SmallVectorImpl* ops_to_visit) { + Operation* def = operand.getDefiningOp(); + if (def && def->getParentOp() == graph && reachable_ops->insert(def).second) { + // Op has not been visited, add to queue to visit later. + ops_to_visit->push_back(def); + } +} + +// Visits all operands of an op where each operand is an output of an Operation +// in the same tf_executor.graph. +void VisitOpOperands(GraphOp graph, Operation* op, + llvm::SmallPtrSetImpl* reachable_ops, + llvm::SmallVectorImpl* ops_to_visit) { + for (Value operand : op->getOperands()) + VisitOpOperand(graph, operand, reachable_ops, ops_to_visit); +} + +// Visits an op and it's associated operands. IslandOps are handled differently +// where it's regions op operands are also visited as values may be implicitly +// captured within. NextIterationSourceOp will also visit it's associated +// NextIterationSinkOp. +void VisitOp(GraphOp graph, Operation* op, + llvm::SmallPtrSetImpl* reachable_ops, + llvm::SmallVectorImpl* ops_to_visit) { + if (auto island = llvm::dyn_cast(op)) { + mlir::visitUsedValuesDefinedAbove( + island.body(), island.body(), [&](OpOperand* operand) { + VisitOpOperand(graph, operand->get(), reachable_ops, ops_to_visit); + }); + } + + VisitOpOperands(graph, op, reachable_ops, ops_to_visit); + + // If op is a `tf_executor.NextIteration.Source`, visit its associated + // `tf_executor.NextIteration.Sink` op. + if (auto source_op = llvm::dyn_cast(op)) { + Operation* sink_op = source_op.GetSink().getOperation(); + if (reachable_ops->insert(sink_op).second) ops_to_visit->push_back(sink_op); + } +} + // Prunes unreachable operations of a tf_executor.graph operation. void PruneGraph(GraphOp graph) { // A graph has a single block which forms a DAG: operations that aren't @@ -36,49 +83,23 @@ void PruneGraph(GraphOp graph) { llvm::SmallPtrSet reachable_ops; llvm::SmallVector ops_to_visit; - // Visit an op's operands if it is output of an Operation in same graph. - auto visit_op = [&](Operation* op) { - for (Value operand : op->getOperands()) { - Operation* def = operand.getDefiningOp(); - if (def && def->getParentOp() == graph && - reachable_ops.insert(def).second) { - // Op has not been visited, add to queue to visit later. - ops_to_visit.push_back(def); - } - } - }; - - // Visit `fetch` operands. - visit_op(graph.GetFetch()); + // Visit fetches first to create a starting point for ops that are reachable. + reachable_ops.insert(graph.GetFetch()); + VisitOpOperands(graph, graph.GetFetch(), &reachable_ops, &ops_to_visit); + // Visit transitive ops until no there are no reachable ops left that have not + // been visited. while (!ops_to_visit.empty()) { Operation* op = ops_to_visit.pop_back_val(); - if (llvm::isa(op)) { - // Visit island and island inner ops operands. 
- op->walk([&](Operation* inner_op) { visit_op(inner_op); }); - continue; - } else { - // Op is not an island, only visit its operands. - visit_op(op); - } - - // If op is a `tf_executor.NextIteration.Source`, visit its associated - // `tf_executor.NextIteration.Sink` op. - if (auto source_op = llvm::dyn_cast(op)) { - Operation* sink_op = source_op.GetSink().getOperation(); - if (reachable_ops.insert(sink_op).second) { - ops_to_visit.push_back(sink_op); - } - } + VisitOp(graph, op, &reachable_ops, &ops_to_visit); } - // Erase unreachable ops in reverse order. - for (Operation& op : llvm::make_early_inc_range( - llvm::drop_begin(llvm::reverse(graph.GetBody()), 1))) { - if (reachable_ops.find(&op) == reachable_ops.end()) { - op.erase(); - } - } + // Erase unreachable ops in reverse order so references don't need to be + // dropped before removing an op. Going in reverse order will guarantee that + // when an op to be erased is reached, there are no users left. + for (Operation& op : + llvm::make_early_inc_range(llvm::reverse(graph.GetBody()))) + if (!reachable_ops.contains(&op)) op.erase(); } namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc b/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc new file mode 100644 index 00000000000..776afd72ad5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc @@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Utils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { + +namespace { + +// Clones FuncOp's until they have a single use only (or no users). +// +// The tf-shape-inference pass doesn't support functions that have more than +// a single use. But some real code from frontends does end up creating code +// like that. For example, the same LSTM cell function or loop body function +// will be reused. +// +// This pass clones functions as needed to establish the invariant that all +// functions have a single use. This can in principle cause exponential code +// size bloat, and should in general be guided by a proper cost model. +// +// There are two factors which should be considered by a principled replacement +// to this pass: +// +// 1. TF currently relies on "sufficiently good shape inference" for +// correctness so for now the cost of doing this seems acceptable since +// pathological cases haven't hit us yet. +// +// 2. 
Cloning functions can help by allowing code to be specialized (much as +// inlining does). In fact, tf-shape-inference attempts to do specialization +// of callees which is difficult if callees have multiple uses. +class GuaranteeAllFuncsOneUse + : public PassWrapper> { + public: + void runOnOperation() override { + if (failed(Run())) { + signalPassFailure(); + } + } + + LogicalResult Run() { + auto module = getOperation(); + + // Overall strategy: + // Fixed point iteration, iteratively applying a rule that clones + // any FuncOp with more than one use to eliminate its uses. + + SymbolTable symbol_table(module); + bool made_changes = false; + // This value needs to be low enough to actually stop compilation in a + // reasonable time, but not too low that it blocks real programs. + // This number was chosen semi-randomly. + const int k_max_clones = 1000; + int num_clones = 0; + do { + made_changes = false; + for (auto func : llvm::make_early_inc_range(module.getOps())) { + auto uses_optional = symbol_table.getSymbolUses(func, module); + if (!uses_optional.hasValue()) { + return func.emitError() << "could not walk uses of func"; + } + auto &uses = *uses_optional; + if (llvm::size(uses) <= 1) { + continue; + } + // At this point, we know we are going to change the module. + made_changes = true; + for (const SymbolTable::SymbolUse &use : llvm::drop_begin(uses, 1)) { + if (num_clones++ > k_max_clones) { + return func.emitError() + << "reached cloning limit (likely recursive call graph or " + "repeated diamond-like call structure " + "or just very large program)"; + } + auto new_func = func.clone(); + symbol_table.insert(new_func); + new_func.setVisibility(SymbolTable::Visibility::Private); + if (failed(symbol_table.replaceAllSymbolUses(func, new_func.getName(), + use.getUser()))) { + return func.emitError() << "could not replace symbol use"; + } + } + } + } while (made_changes); + + return success(); + } +}; + +} // namespace + +std::unique_ptr> CreateGuaranteeAllFuncsOneUsePass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-guarantee-all-funcs-one-use", + "Guarantee all FuncOp's have only a single use."); + +} // namespace TF + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc new file mode 100644 index 00000000000..615ca26012e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import.cc @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +static constexpr int kTextFileIndex_WholeLine = -2; +static constexpr int kTextFileIndex_LineNumber = -1; + +// InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the +// corresponding LookupTableImportV2Op if possible. +class InitTextFileToImportPass + : public mlir::PassWrapper { + public: + explicit InitTextFileToImportPass() {} + + private: + void runOnFunction() override; +}; + +class ConvertInitializeTableFromTextFileV2 + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op, + PatternRewriter& rewriter) const override { + // Now, this pattern matching only supports the following case, which is + // commonly used among inference use cases: + // + // tf.lookup.TextFileInitializer( + // "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE, + // tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ") + // + // In the above case, the delimiter will be not used since the key is just a + // whole line and value is a line number. + if (op.key_index() != kTextFileIndex_WholeLine || + op.value_index() != kTextFileIndex_LineNumber || + op.vocab_size() != -1) { + return failure(); + } + + // Try to find filename from constant op. + DenseStringElementsAttr filename_attr; + if (!matchPattern(op.filename().getDefiningOp(), + m_Constant(&filename_attr))) { + return failure(); + } + StringRef filename = filename_attr.getRawStringData()[0]; + + // Read the content of the file. + std::string error_message; + auto file = openInputFile(filename, &error_message); + if (!file) { + return op.emitOpError("failed to open vocabulary file") + << " (" << filename.str() << "): " << error_message; + } + + // Splits into lines. + SmallVector lines; + file->getBuffer().split(lines, "\n", -1, false); + + // Map each line to line number, starting from zero. + SmallVector line_nums; + line_nums.resize(lines.size()); + std::iota(line_nums.begin(), line_nums.end(), 0); + + // Create constant ops for keys an values. + Value key_constant_tensor = rewriter.create( + op.getLoc(), + DenseStringElementsAttr::get( + RankedTensorType::get(static_cast(lines.size()), + StringType::get(rewriter.getContext())), + lines)); + + Value value_constant_tensor = rewriter.create( + op.getLoc(), rewriter.getI64TensorAttr(line_nums)); + + // Replace the given op with LookupTableImportV2Op. + rewriter.create(op.getLoc(), op.table_handle(), + key_constant_tensor, + value_constant_tensor); + rewriter.eraseOp(op); + return success(); + } +}; + +void InitTextFileToImportPass::runOnFunction() { + OwningRewritePatternList patterns; + MLIRContext* context = &getContext(); + FuncOp func = getFunction(); + + patterns.insert(context); + applyPatternsAndFoldGreedily(func, patterns); +} + +} // namespace + +// Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops. 
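A standalone sketch (plain C++, not MLIR) of what the pattern above materializes for a WHOLE_LINE / LINE_NUMBER initializer: each line of the vocabulary file becomes a string key and its zero-based line number becomes the int64 value, which is the pair of constant tensors fed to LookupTableImportV2. The file name here is illustrative.

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::ifstream file("vocab.txt");  // e.g. "apple\nbanana\ngrape"
  if (!file) {
    std::cerr << "failed to open vocabulary file\n";
    return 1;
  }
  std::vector<std::string> keys;
  std::vector<int64_t> values;
  std::string line;
  while (std::getline(file, line)) {
    values.push_back(static_cast<int64_t>(keys.size()));  // line number
    keys.push_back(line);                                 // whole line
  }
  for (size_t i = 0; i < keys.size(); ++i)
    std::cout << keys[i] << " -> " << values[i] << "\n";
}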
+std::unique_ptr> CreateInitTextFileToImportPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-init-text-file-to-import", + "convert InitializeTableFromTextFileV2 ops to LookupTableImportV2Op to " + "remove the dependency on asset files"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc new file mode 100644 index 00000000000..96a04fa6eeb --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/init_text_file_to_import_test_pass.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TF { +namespace { + +// InitTextFileToImportTestPass generates a temporary file and run the +// InitTextFileToImportPass for testing purpose. +class InitTextFileToImportTestPass + : public mlir::PassWrapper> { + public: + explicit InitTextFileToImportTestPass() {} + + private: + void runOnOperation() override; +}; + +void InitTextFileToImportTestPass::runOnOperation() { + ModuleOp module = getOperation(); + + // Create a temporary vocab file. + int fd; + SmallString<256> filename; + std::error_code error_code = + llvm::sys::fs::createTemporaryFile("text", "vocab", fd, filename); + if (error_code) return signalPassFailure(); + + llvm::ToolOutputFile temp_file(filename, fd); + const char* dictionary_in_lines = + "apple\n" + "banana\n" + "grape"; + temp_file.os() << dictionary_in_lines; + temp_file.os().flush(); + + // Replace filename constant ops to use the temporary file. + MLIRContext* context = &getContext(); + + for (FuncOp func : module.getOps()) { + llvm::SmallVector constant_ops(func.getOps()); + for (auto op : constant_ops) { + ShapedType shaped_type = + RankedTensorType::get({1}, StringType::get(context)); + + DenseStringElementsAttr attr; + if (!matchPattern(op.getOperation(), m_Constant(&attr))) { + continue; + } + + ArrayRef values = attr.getRawStringData(); + if (values.size() != 1 || values[0] != "%FILE_PLACEHOLDER") { + continue; + } + + op.valueAttr(DenseStringElementsAttr::get(shaped_type, {filename})); + } + } + + // Run the lowering pass. 
+ PassManager pm(context); + pm.addPass(CreateInitTextFileToImportPass()); + if (failed(pm.run(module))) return signalPassFailure(); +} + +} // namespace + +static PassRegistration pass( + "tf-init-text-file-to-import-test", + "generate a temporary file and invoke InitTextFileToImportPass"); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index bce18c0b4b7..9f67a3e7e71 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -106,8 +106,8 @@ LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect, void LaunchToDeviceAttributePass::runOnFunction() { const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); if (!tf_dialect) { - signalPassFailure(); getFunction().emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); } auto result = getFunction().walk([&](tf_device::LaunchOp launch) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc index c263dcc75d1..ad241ef9488 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/legalize_hlo.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/framework/kernel_shape_util.h" namespace mlir { @@ -744,9 +745,7 @@ void LegalizeHloToTf::runOnFunction() { // Add legalization patterns to the list. OwningRewritePatternList patterns; - populateWithGenerated(&context, &patterns); - patterns.insert(&context); + PopulateLegalizeHloToTfPatterns(&patterns, &context); ConversionTarget target(context); target.addLegalDialect(); @@ -762,6 +761,13 @@ static PassRegistration pass( } // end namespace +void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList *patterns, + MLIRContext *context) { + populateWithGenerated(context, patterns); + patterns->insert(context); +} + std::unique_ptr> CreateLegalizeHloToTfPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index c0de6f557ab..483c84b3e80 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -113,12 +113,42 @@ Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { // Lowers AddN op to a sequence of AddV2 ops to accumulate operands. // +// Note that to improve the parallelism, AddN op uses tree-based reduction. 
+// For example, tf.AddN([0, 1, 2, 3, 4]) behaves as follows: +// +// 0 1 2 3 4 +// | | | | | +// ------- ------- | +// | | | +// 5 6 | +// | | | +// ------------- | +// | | +// 7 | +// | | +// ---------------- +// | +// 8 +// +// Example: +// // %result = "tf.AddN"(%0, %1, %2) // // is lowered to: // -// %sum_0 = "tf.AddV2"(%0, %1) -// %result = "tf.AddV2"(%sum_0, %2) +// %sum0 = "tf.AddV2"(%0, %1) +// %result = "tf.AddV2"(%sum0, %2) +// +// While +// +// %result = "tf.AddN"(%0, %1, %2, %3, %4) +// +// is lowered to: +// +// %sum0 = "tf.AddV2"(%0, %1) +// %sum1 = "tf.AddV2"(%2, %3) +// %sum2 = "tf.AddV2"(%sum0, %sum1) +// %result = "tf.AddV2"(%sum2, %4) // class LowerAddNOp : public OpRewritePattern { public: @@ -131,14 +161,23 @@ class LowerAddNOp : public OpRewritePattern { // support variant type so variant types require special handling. if (getElementTypeOrSelf(op.getType()).isa()) return failure(); - // TODO(hinsu): Improve parallelism by splitting operands in two halves and - // accumulating them first. - Value result = *op.inputs().begin(); - for (Value operand : llvm::drop_begin(op.inputs(), 1)) { - result = rewriter.create(op.getLoc(), result, operand); + llvm::SmallVector operands(op.inputs().begin(), + op.inputs().end()); + + int64_t n = operands.size(); + // Keep doing tree-based reduction when there are more than one operand. + while (n > 1) { + for (int64_t i = 0; i < n; i += 2) { + // Add two adjacent operands if applicable. + operands[i / 2] = (i + 1 < n) + ? rewriter.create( + op.getLoc(), operands[i], operands[i + 1]) + : operands[i]; + } + n = (n + 1) / 2; } - rewriter.replaceOp(op, result); + rewriter.replaceOp(op, operands[0]); return success(); } }; @@ -344,12 +383,56 @@ class LowerPackOp : public OpRewritePattern { } }; +// Lowers `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness hints, +// since we currently don't have an implementation that can use this +// information. Adds appropriate casts where necessary to align element types +// of operands and result for `TF::MatMulOp`. +class LowerSparseMatMulOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SparseMatMulOp op, + PatternRewriter &rewriter) const override { + // Result type must be f32 for applying the pattern (currently this is + // required by the op anyway but this might change). + if (!op.product().getType().cast().getElementType().isF32()) { + return failure(); + } + MLIRContext *context = rewriter.getContext(); + llvm::SmallVector operands{op.a(), op.b()}; + for (Value &operand : operands) { + TensorType tensor_type = operand.getType().cast(); + Type element_type = tensor_type.getElementType(); + if (element_type.isF32()) continue; + // Element type can either be f32 or bf16 for `SparseMatMulOp` so it + // must be bf16 here. + assert(element_type.isBF16()); + Type tensor_type_f32; + if (tensor_type.hasRank()) { + tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(), + FloatType::getF32(context)); + } else { + tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context)); + } + // Add cast to f32 to conform with element type of result. 
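A standalone sketch (plain C++, not MLIR) of the pairwise index arithmetic the AddN lowering above uses, applied to plain doubles: each round halves the number of partial sums, so the dependence depth is O(log n) instead of O(n). The use of doubles is illustrative only.

#include <iostream>
#include <vector>

double TreeSum(std::vector<double> operands) {
  size_t n = operands.size();
  while (n > 1) {
    for (size_t i = 0; i < n; i += 2)
      operands[i / 2] = (i + 1 < n)
                            ? operands[i] + operands[i + 1]
                            : operands[i];  // odd element passes through
    n = (n + 1) / 2;
  }
  return operands[0];
}

int main() {
  std::cout << TreeSum({0, 1, 2, 3, 4}) << "\n";  // 10, matching the diagram above
}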
+ operand = + rewriter.create(op.getLoc(), tensor_type_f32, operand); + } + Value result = rewriter.create( + op.getLoc(), op.product().getType(), operands[0], operands[1], + op.transpose_a(), op.transpose_b()); + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + } // namespace void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); + LowerPackOp, LowerSparseMatMulOp>(context); populateWithGenerated(context, patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc new file mode 100644 index 00000000000..ece26dca416 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -0,0 +1,177 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" + +namespace mlir { +namespace TFDevice { + +namespace { + +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; + +// This pass marks unsupported ops in a device cluster with +// `_xla_outside_compilation` attribute so the operations will run on the host +// instead of the device. Unsupported ops are ops that can not be code +// generated to run on the device for the cluster. +struct MarkOpsForOutsideCompilation + : public PassWrapper> { + void runOnOperation() override; +}; + +// TODO(b/159128666): Check the control flow legalization passes instead once +// added. +void AddSupportedControlFlowOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { + supported_ops->insert( + OperationName(TF::IfRegionOp::getOperationName(), context)); + supported_ops->insert( + OperationName(TF::WhileRegionOp::getOperationName(), context)); + supported_ops->insert( + OperationName(TF::YieldOp::getOperationName(), context)); +} + +// These embedding ops are rewritten when running TPUCompileOp. 
+void AddRewrittenEmbeddingOps(MLIRContext* context, + llvm::DenseSet* supported_ops) { + supported_ops->insert(OperationName( + TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context)); + supported_ops->insert(OperationName( + TF::SendTPUEmbeddingGradientsOp::getOperationName(), context)); +} + +bool HasStringOperand(Operation& op) { + for (auto operand : op.getOperands()) { + if (getElementTypeOrSelf(operand).isa()) return true; + } + return false; +} + +bool HasStringResult(Operation& op) { + for (auto result : op.getResults()) { + if (getElementTypeOrSelf(result).isa()) return true; + } + return false; +} + +bool MatchesPattern(Operation& op, + const llvm::DenseSet& supported_ops) { + return (supported_ops.contains(op.getName())); +} + +// Checks if the op is supported inside of a device cluster. Ops not +// in `tf_dialect` are considered supported. +bool IsSupportedOp(Operation& op, + const llvm::DenseSet& supported_ops, + const Dialect* tf_dialect) { + if (op.getDialect() != tf_dialect) + return true; + else + return !HasStringOperand(op) && !HasStringResult(op) && + (MatchesPattern(op, supported_ops) || + mhlo::IsOpAllowedTf2XlaFallback(&op)); +} + +// Checks all regions of `op` for captured string operands. +bool HasCapturedStringOperand(Operation* op) { + bool string_operand = false; + for (auto& region : op->getRegions()) { + mlir::visitUsedValuesDefinedAbove( + region, region, [&](mlir::OpOperand* operand) { + if (getElementTypeOrSelf(operand->get()).isa()) + string_operand = true; + }); + if (string_operand) return string_operand; + } + return string_operand; +} + +// Marks uncompilable ops that are in `tf_dialect` for outside compilation. +LogicalResult MarkUncompilableOps( + const Dialect* tf_dialect, Block* block, + llvm::DenseSet& supported_ops) { + block->walk([&](Operation* op) { + if (!IsSupportedOp(*op, supported_ops, tf_dialect)) { + op->setAttr(kXlaOutsideCompilationAttr, + StringAttr::get("auto", op->getContext())); + } + if (llvm::isa(op)) { + if (HasCapturedStringOperand(op)) { + op->setAttr(kXlaOutsideCompilationAttr, + StringAttr::get("auto", op->getContext())); + } + } + }); + return success(); +} + +void MarkOpsForOutsideCompilation::runOnOperation() { + auto module = getOperation(); + const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + if (!tf_dialect) { + getOperation().emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); + } + OwningRewritePatternList patterns; + mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); + + // `supported_ops` contains the name of all of the ops that can potentially be + // lowered into HLO on the device. This doesn't always mean that the op can + // be lowered in the future passes but if the op is not in this set, it can't + // be lowered in a subsequent pass. 
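A standalone sketch (plain C++, not MLIR) of the decision rule MarkUncompilableOps applies: a TF-dialect op stays on the device only if it has no string operands or results and its name is in the set of legalizable ops; otherwise it is tagged _xla_outside_compilation="auto". The op model and set contents are illustrative, and the tf2xla fallback check is omitted.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct FakeOp {
  std::string dialect;  // "tf" or something else
  std::string name;     // e.g. "tf.MatMul"
  bool has_string_operand_or_result = false;
  std::map<std::string, std::string> attrs;
};

void MarkIfUnsupported(FakeOp& op, const std::set<std::string>& supported) {
  if (op.dialect != "tf") return;  // non-TF ops are left alone
  bool ok = !op.has_string_operand_or_result && supported.count(op.name) > 0;
  if (!ok) op.attrs["_xla_outside_compilation"] = "auto";
}

int main() {
  std::set<std::string> supported = {"tf.MatMul", "tf.IfRegion", "tf.Yield"};
  std::vector<FakeOp> ops = {{"tf", "tf.MatMul"},
                             {"tf", "tf.AsString", true},
                             {"tf", "tf.SomeCustomOp"}};
  for (FakeOp& op : ops) {
    MarkIfUnsupported(op, supported);
    std::cout << op.name
              << (op.attrs.count("_xla_outside_compilation")
                      ? " -> outside compilation\n"
                      : " -> device\n");
  }
}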
+ llvm::DenseSet supported_ops; + for (auto& pattern : patterns) { + supported_ops.insert(*pattern->getRootKind()); + } + AddSupportedControlFlowOps(module.getContext(), &supported_ops); + AddRewrittenEmbeddingOps(module.getContext(), &supported_ops); + + auto result = module.walk([&](tf_device::ClusterOp cluster) { + if (failed( + MarkUncompilableOps(tf_dialect, &cluster.GetBody(), supported_ops))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace + +std::unique_ptr> +CreateMarkOpsForOutsideCompilationPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-mark-ops-for-outside-compilation", + "Marks unsupported ops a device cluster for outside compilation."); + +} // namespace TFDevice +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index 67a6c8dd6dd..6fee693554e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -68,9 +68,8 @@ bool IsResource(Value value) { return IsResourceType(value.getType()); } class ResourceAnalyzer { public: explicit ResourceAnalyzer(ModuleOp module) { - SymbolTable symbol_table(module); for (auto func : module.getOps()) { - AnalyzeFunc(func, symbol_table); + AnalyzeFunc(func); } } @@ -89,7 +88,7 @@ class ResourceAnalyzer { // written". Do this recursively across the chain of funcs via call or control // flow ops. // TODO(ashwinm): Move to iterative traversal. - LogicalResult AnalyzeFunc(FuncOp func, const SymbolTable& symbol_table) { + LogicalResult AnalyzeFunc(FuncOp func) { // Avoid infinite recursion. if (!discovered_.insert(func).second) { return success(); @@ -104,24 +103,20 @@ class ResourceAnalyzer { return; } if (auto call = dyn_cast(op)) { - if (auto sym = op->getAttrOfType("f")) { - PropagatePotentiallyWrittenUpFromCallee( - sym.cast().getValue(), call.getArgOperands(), - symbol_table); + if (auto func = dyn_cast(call.resolveCallable())) { + PropagatePotentiallyWrittenUpFromCallee(func, call.getArgOperands()); } return; } if (auto if_op = dyn_cast(op)) { - for (auto callee : {if_op.then_branch(), if_op.else_branch()}) { - PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input(), - symbol_table); + for (auto callee : {if_op.then_func(), if_op.else_func()}) { + PropagatePotentiallyWrittenUpFromCallee(callee, if_op.input()); } return; } if (auto while_op = dyn_cast(op)) { - for (auto callee : {while_op.cond(), while_op.body()}) { - PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input(), - symbol_table); + for (auto callee : {while_op.cond_func(), while_op.body_func()}) { + PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); } return; } @@ -149,15 +144,13 @@ class ResourceAnalyzer { }); } - // Given a funcOp associated with the callee and operands from the + // Given a FuncOp associated with the callee and operands from the // corresponding callOp, propagate the potentially written decision to the // callOp's operands, if the corresponding func's arguments are potentially // written resources. 
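A standalone sketch (plain C++, not MLIR) of the memoized propagation ResourceAnalyzer performs: each function records which of its arguments it may write, analyzing a caller first analyzes the callee (once, guarded by a "discovered" set), and written callee arguments mark the caller operands that feed them. The call graph and names here are illustrative.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

struct Func {
  std::set<int> directly_written_args;  // e.g. from AssignVariableOp uses
  std::vector<std::pair<std::string, std::vector<int>>> calls;  // callee, arg map
};

std::map<std::string, Func> funcs;
std::set<std::string> discovered;
std::map<std::string, std::set<int>> written_args;  // analysis result

void Analyze(const std::string& name) {
  if (!discovered.insert(name).second) return;  // already analyzed (or in progress)
  const Func& f = funcs[name];
  written_args[name] = f.directly_written_args;
  for (const auto& call : f.calls) {
    Analyze(call.first);
    // If callee argument i may be written, the caller value passed into it is
    // potentially written as well.
    for (int i = 0; i < static_cast<int>(call.second.size()); ++i)
      if (written_args[call.first].count(i))
        written_args[name].insert(call.second[i]);
  }
}

int main() {
  funcs["body"] = {{0}, {}};              // writes its first argument
  funcs["main"] = {{}, {{"body", {1}}}};  // passes its arg 1 into body's arg 0
  Analyze("main");
  for (int arg : written_args["main"])
    std::cout << "main arg " << arg << " may be written\n";
}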
void PropagatePotentiallyWrittenUpFromCallee( - StringRef callee, Operation::operand_range propagate_to, - const SymbolTable& symbol_table) { - auto func = symbol_table.lookup(callee); - AnalyzeFunc(func, symbol_table); + FuncOp func, Operation::operand_range propagate_to) { + AnalyzeFunc(func); for (auto t : llvm::zip(func.getArguments(), propagate_to)) { if (!IsResource(std::get<0>(t))) { continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc index c13d7de754e..1332c8b6e59 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallel_execute_to_islands.cc @@ -71,6 +71,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -111,8 +112,8 @@ LogicalResult ExpandParallelExecuteToIslands( // executed. llvm::SetVector region_inputs; getUsedValuesDefinedAbove(*execute_region, region_inputs); - llvm::SmallVector execution_control_inputs; - if (region_inputs.empty()) + llvm::SmallVector execution_control_inputs; + if (region_inputs.empty() && input_sink_island) execution_control_inputs.emplace_back(input_sink_island.control()); // Collect result types and operands. @@ -147,13 +148,22 @@ tf_executor::IslandOp CreateInputBarrierIsland( OpBuilder* builder, tf_executor::IslandOp island_op) { builder->setInsertionPoint(island_op); - llvm::SetVector island_inputs; - getUsedValuesDefinedAbove(island_op.body(), island_inputs); + llvm::SetVector all_inputs; + getUsedValuesDefinedAbove(island_op.body(), all_inputs); + // Filter out values that are arguments and doesn't need to be part of the + // entry barrier. + llvm::SmallVector island_inputs; llvm::SmallVector input_types; - input_types.reserve(island_inputs.size()); - for (const auto& input_val : island_inputs) - input_types.emplace_back(input_val.getType()); + island_inputs.reserve(all_inputs.size()); + input_types.reserve(all_inputs.size()); + for (Value val : all_inputs) { + if (!val.isa()) { + island_inputs.push_back(val); + input_types.push_back(val.getType()); + } + } + if (island_inputs.empty() && island_op.controlInputs().empty()) return {}; // Create new island for that forwards all inputs. auto control_type = tf_executor::ControlType::get(island_op.getContext()); @@ -190,7 +200,7 @@ tf_executor::IslandOp CreateOutputBarrierIsland( builder->setInsertionPoint(island_op); auto island_output_sink = builder->create( island_op.getLoc(), llvm::to_vector<8>(island_op.getResultTypes()), - island_operands, llvm::ArrayRef{}); + island_operands); island_output_sink.body().push_back(new Block); return island_output_sink; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc new file mode 100644 index 00000000000..527af0934ea --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/parallelize_embedding_params_ops_pass.cc @@ -0,0 +1,152 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation parallelizes TPU embedding params assigned to different +// shards using the parallel execute op. This is useful to avoid introducing +// control dependency between these ops that are known to be independent. + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" + +namespace mlir { +namespace TFDevice { + +namespace { + +struct ParallelizeEmbeddingParamsOpsPass + : public PassWrapper { + void runOnFunction() override; +}; + +bool IsLoadTPUEmbeddingParmasOp(Operation& op) { + static const auto* algorithms = []() { + auto* algorithms = new llvm::SmallSet(); + for (tensorflow::tpu::OptimizationAlgorithm alg : + tensorflow::tpu::GetOptimizationAlgorithms()) { + const auto alg_name = tensorflow::tpu::GetOptimizationAlgorithmName(alg); + algorithms->insert(alg_name); + } + return algorithms; + }(); + StringRef op_name = op.getName().getStringRef(); + return op_name.consume_front("tf.LoadTPUEmbedding") && + op_name.consume_back("Parameters") && + algorithms->contains(op_name.str()); +} + +static LogicalResult RunOnIsland(tf_executor::IslandOp island) { + Block* block = island.getBody(); + + // Map from op to the id of the shard it is assigned for ops that can execute + // in parallel across shards. + llvm::SmallMapVector assigned_shard; + llvm::SmallVector resources; + llvm::SmallSet shard_ids; + for (Operation& op : llvm::reverse(*block)) { + int64_t shard = -1; + if (IsLoadTPUEmbeddingParmasOp(op)) { + auto shard_id = op.getAttrOfType("shard_id"); + if (!shard_id) { + return op.emitOpError("requires 'shard_id' integer attribute"); + } + shard = shard_id.getInt(); + shard_ids.insert(shard); + } else if (auto read_op = llvm::dyn_cast(op)) { + if (assigned_shard.empty()) continue; + + for (Operation* user : op.getUsers()) { + auto iter = assigned_shard.find(user); + if (iter == assigned_shard.end() || + (shard != -1 && shard != iter->second)) { + shard = -1; + break; + } + shard = iter->second; + } + if (shard != -1) resources.push_back(read_op.resource()); + } + + if (shard != -1) assigned_shard.insert(std::make_pair(&op, shard)); + } + + // No transformations are required. 
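A standalone sketch (plain C++, not MLIR) of the user-agreement rule the pass above applies to ReadVariableOps: a read may move into a shard's parallel_execute region only if every user of its result is already assigned to that same shard; otherwise it stays unassigned (shard == -1). The op names are illustrative.

#include <iostream>
#include <map>
#include <string>
#include <vector>

int ShardForRead(const std::vector<std::string>& users,
                 const std::map<std::string, int>& assigned_shard) {
  int shard = -1;
  for (const std::string& user : users) {
    auto it = assigned_shard.find(user);
    if (it == assigned_shard.end() || (shard != -1 && shard != it->second))
      return -1;  // unknown user, or users disagree: leave the read unassigned
    shard = it->second;
  }
  return shard;
}

int main() {
  std::map<std::string, int> assigned_shard = {{"LoadAdamParameters_0", 0},
                                               {"LoadAdamParameters_1", 1}};
  std::cout << ShardForRead({"LoadAdamParameters_0"}, assigned_shard) << "\n";  // 0
  std::cout << ShardForRead({"LoadAdamParameters_0", "LoadAdamParameters_1"},
                            assigned_shard)
            << "\n";  // -1
}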
+ int num_shards = shard_ids.size(); + if (num_shards <= 1) return success(); + + // If the resources are used for ops other than read variable op, then moving + // read variable ops to the parallel_execute may not preserve the semantics. + for (Value resource : resources) { + for (Operation* user : resource.getUsers()) + if (!llvm::isa(*user)) return success(); + } + + // Create parallel_execute op at the end of the block and move operations + // to their corresponding shard. + auto builder = OpBuilder::atBlockTerminator(block); + auto parallel_execute_op = builder.create( + island.getLoc(), num_shards, llvm::ArrayRef()); + for (int shard_id = 0; shard_id < num_shards; ++shard_id) { + mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard_id); + builder.setInsertionPointToStart(&b); + builder.create(island.getLoc()); + } + + for (auto op_shard : assigned_shard) { + int64_t shard = op_shard.second; + if (shard >= num_shards) { + return island.emitOpError( + "load tpu embedding ops require continuous range of shards"); + } + mlir::Block& b = parallel_execute_op.GetRegionBlockWithIndex(shard); + op_shard.first->moveBefore(&b, b.begin()); + } + return success(); +} + +void ParallelizeEmbeddingParamsOpsPass::runOnFunction() { + getFunction().walk([&](tf_executor::IslandOp island) { + if (failed(RunOnIsland(island))) { + signalPassFailure(); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); +} + +} // namespace + +std::unique_ptr> +CreateParallelizeEmbeddingParamsOpsPass() { + return std::make_unique(); +} +} // namespace TFDevice +} // namespace mlir + +static mlir::PassRegistration + pass("tf-parallize-embedding-params-ops", + "Parallelizes TPU embedding params assigned to different shards using " + "the parallel_execte op"); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 5af8a0195a4..3be6c9e1a70 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -18,13 +18,15 @@ limitations under the License. #include +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project namespace mlir { // Creates a pass that breaks up an island with multiple ops into multiple // islands, each with a single op. -std::unique_ptr> CreateBreakUpIslandsPass(); +std::unique_ptr> CreateBreakUpIslandsPass(); // Creates a pass that converts mlir functions consisting of mlir ops into a // tf_executor dialect as a single island. @@ -58,6 +60,9 @@ std::unique_ptr> CreateMaterializePassthroughOpPass(); // Performs Shape Inference on the TensorFlow dialect using the global registry. std::unique_ptr> CreateTFShapeInferencePass(); +// Guarantee that all FuncOp's have a single use. +std::unique_ptr> CreateGuaranteeAllFuncsOneUsePass(); + // Optional pass which will unroll BatchMatMul and use only MatMul std::unique_ptr> CreateUnrollBatchMatMulPassPass(); @@ -148,6 +153,10 @@ CreateTensorArrayOpsDecompositionPass(); // Create a pass that legalize HLO to TF dialect. std::unique_ptr> CreateLegalizeHloToTfPass(); +// Addds the HLO to TF rewrite patterns to the specified pattern list. +void PopulateLegalizeHloToTfPatterns(OwningRewritePatternList* patterns, + MLIRContext* context); + // Matches sequence of ops to TensorFlow fused kernels. This pass should not be // generally used beyond exporting to runtimes that supports these ops. 
In the // future these fusions may be codegen'd automatically. @@ -155,6 +164,10 @@ std::unique_ptr> CreateFusedKernelMatcherPass(); // Creates function pass to select device index/fold tf.DeviceIndex. std::unique_ptr> CreateDeviceIndexSelectorPass(); + +// Creates function pass to replace InitializeTableFromTextFileV2Ops with +// LookupTableImportV2Op ops. +std::unique_ptr> CreateInitTextFileToImportPass(); } // namespace TF namespace tf_executor { @@ -226,17 +239,27 @@ std::unique_ptr> CreateReplicateInvariantOpHoistingPass(); // Creates a pass that forms replica `tf_executor.island` from a single // `tf_device.replicate` island. -std::unique_ptr> CreateReplicateToIslandPass(); +std::unique_ptr> CreateReplicateToIslandPass(); // Creates a pass that creates `tf_executor.island` from a single // `tf_device.parallel_execute` island. std::unique_ptr> CreateParallelExecuteToIslandsPass(); +// Create a pass to parallelize TPU embedding params assigned to different +// shards using the parallel_execte op. +std::unique_ptr> +CreateParallelizeEmbeddingParamsOpsPass(); + // Creates a pass that annotates whether a LaunchFuncOp's parameters have the // same data across replicas. std::unique_ptr> CreateAnnotateParameterReplicationPass(); +// Creates a pass that marks unsupported ops in device cluster for outside +// compilation. +std::unique_ptr> +CreateMarkOpsForOutsideCompilationPass(); + // Creates a pass that hoists a `tf_device.launch` body and assigns a `device` // attribute to each TensorFlow dialect op in the body based on the `device` // attribute on the `tf_device.launch`. @@ -250,7 +273,7 @@ std::unique_ptr> CreateTPUClusterFormationPass(); // Creates a pass that allows TPU program inputs to have layouts determined at // run time. -std::unique_ptr> CreateTPUDynamicLayoutPass(); +std::unique_ptr> CreateTPUDynamicLayoutPass(); // Creates a pass that remaps and assigns padding map from a // `tf_device.launch_func` `padding_map` attribute to its encapsulated function. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index 961287b0b1f..89910d6b3a5 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -304,7 +304,7 @@ LogicalResult PromoteResourcesToArguments( continue; } - const auto index = resource_and_index.index(); + const int64_t index = resource_and_index.index(); const bool is_var_handle = index >= var_handles_start_idx; if (resource.write) { if (!is_var_handle || resource.read) { @@ -342,7 +342,8 @@ LogicalResult PromoteResourcesToArguments( } // Rewrite return if there are variable writes. 
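A hedged sketch of how pass factories declared in passes.h are typically composed on a PassManager, assuming an MLIR/TensorFlow build of this era (includes and result-op types follow the declarations in this diff). The ordering and nesting below are illustrative only, not the pipeline TensorFlow registers.

#include "mlir/IR/Function.h"       // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

void BuildIllustrativePipeline(mlir::PassManager& pm) {
  // Module-level passes run once over the whole module.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass());
  // Function-level passes are nested so they run on every FuncOp.
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateInitTextFileToImportPass());
}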
- if (return_operands.size() > num_results_before) { + const int return_operands_size = return_operands.size(); + if (return_operands_size > num_results_before) { builder.create(return_op.getLoc(), return_operands); return_op.erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc index 5fc35361bca..104f11e0cc0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -65,8 +65,15 @@ class ConvertReadonlyReferenceVariablesToResourceVariablesPass StringRef GetNodeNameFromClassAttr(Operation *op) { ArrayAttr classes_attr = op->getAttrOfType(kClassAttr); if (!classes_attr) { - op->emitOpError() << "has no '_class' attribute"; - return StringRef(); + // Attampt to parse "_class" from the IdentityOp that follows VariableV2. + // For read-only reference variables, IdentityOp should be the only user of + // VariableV2. + auto identity_op = op->getUsers().begin(); + classes_attr = identity_op->getAttrOfType(kClassAttr); + if (!classes_attr) { + op->emitOpError() << "has no '_class' attribute"; + return StringRef(); + } } StringRef result; @@ -153,7 +160,7 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() { builder.setInsertionPoint(user); ReadVariableOp read_variable_op = builder.create( user->getLoc(), ArrayRef{tensor_type}, - ArrayRef{var_handle_op}, ArrayRef{}); + ArrayRef{var_handle_op}); user->getResult(0).replaceAllUsesWith(read_variable_op.getResult()); user->erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index ca0467942ca..ba876e08fbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -15,9 +15,11 @@ limitations under the License. // This transformation pass transforms region bases control flow operations in // the TensorFlow dialect to their functional counterparts, i.e., -// tf.IfRegion -> tf.If +// tf.IfRegion -> tf.If and tf.WhileRegion -> tf.While #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -34,8 +36,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#define DEBUG_TYPE "tf-region-cf-to-functional" + namespace mlir { namespace TF { @@ -48,6 +53,7 @@ struct RegionControlFlowToFunctional private: LogicalResult ConvertIfOp(IfRegionOp if_region); + LogicalResult ConvertWhileOp(WhileRegionOp while_region); // Get unique name by using the loc to name mapping. 
std::string GetName(Operation* op, StringRef suffix); @@ -61,20 +67,20 @@ std::string RegionControlFlowToFunctional::GetName(Operation* op, return (mapper.GetUniqueName(op) + suffix).str(); } -// Returns all the external values referenced from the given set of regions. If -// the external value is a constant, sink it into the region instead (and do not +// Returns all the external values referenced from the given regions. If the +// external value is a constant, sink it into the region instead (and do not // add it to the returned vector). -llvm::SmallVector CollectExternValues(ArrayRef regions) { - llvm::SetVector extern_values_set; +llvm::SmallVector CollectExternValues(Region& first, Region& second) { + llvm::SetVector extern_values; - for (auto region : regions) { + for (Region* region : {&first, &second}) { llvm::SetVector region_extern_values; getUsedValuesDefinedAbove(*region, region_extern_values); // Sink down constants into the functions. for (auto extern_value : region_extern_values) { if (!matchPattern(extern_value, m_Constant())) { - extern_values_set.insert(extern_value); + extern_values.insert(extern_value); continue; } // Add constant at start of region. @@ -85,28 +91,43 @@ llvm::SmallVector CollectExternValues(ArrayRef regions) { } } - return {extern_values_set.begin(), extern_values_set.end()}; + return llvm::to_vector<4>(extern_values); } // Extracts the contents of a region with a single block into a new function. // `extern_values` is the set of external values that the region refers to. // -// Any inputs to the terminator of the region are converted to return values of -// the function. If any of these values is not exact type as the function's -// return type, appropriate cast operations will be inserted -void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name, +// Inputs to the terminator of the region are converted to return values of +// the function. If `extern_values_passthrough` is true, all the extern values +// are also added as return values from the function +void ExtractSingleBlockRegion(Region& region, StringRef name, llvm::SmallVectorImpl& extern_values, - llvm::SmallVectorImpl& worklist) { + llvm::SmallVectorImpl& worklist, + bool extern_values_passthrough) { ModuleOp module = region.getParentOfType(); auto builder = OpBuilder::atBlockBegin(module.getBody()); auto loc = region.getParentOp()->getLoc(); + Block& entry = region.front(); + int num_region_arguments = entry.getNumArguments(); + Operation* terminator = entry.getTerminator(); + + // Build the function type. Region arguments and extern values together + // become the function arguments, with region arguments going first. + auto input_types = llvm::to_vector<4>(entry.getArgumentTypes()); + for (auto input : extern_values) input_types.push_back(input.getType()); + + // Terminator operands and pass through extern values (if enabled) together + // become the function return values. + auto return_types = llvm::to_vector<4>(terminator->getOperandTypes()); + if (extern_values_passthrough) + for (auto input : extern_values) return_types.push_back(input.getType()); + + auto type = FunctionType::get(input_types, return_types, region.getContext()); // Create new function and extract region body into the function. 
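A standalone sketch (plain C++, not MLIR) of how the outlined function's signature is assembled in ExtractSingleBlockRegion: region block arguments come first, then the captured ("extern") values; results are the terminator operands, optionally followed by the extern values again when extern_values_passthrough is set. Types are modeled as strings and the sample types are illustrative.

#include <iostream>
#include <string>
#include <vector>

struct Signature {
  std::vector<std::string> inputs;
  std::vector<std::string> results;
};

Signature BuildOutlinedSignature(const std::vector<std::string>& region_arg_types,
                                 const std::vector<std::string>& extern_types,
                                 const std::vector<std::string>& terminator_types,
                                 bool extern_values_passthrough) {
  Signature sig;
  sig.inputs = region_arg_types;
  sig.inputs.insert(sig.inputs.end(), extern_types.begin(), extern_types.end());
  sig.results = terminator_types;
  if (extern_values_passthrough)
    sig.results.insert(sig.results.end(), extern_types.begin(), extern_types.end());
  return sig;
}

int main() {
  Signature sig = BuildOutlinedSignature({"tensor<i32>"}, {"tensor<f32>"},
                                         {"tensor<i1>"},
                                         /*extern_values_passthrough=*/true);
  for (const auto& t : sig.inputs) std::cout << "input: " << t << "\n";
  for (const auto& t : sig.results) std::cout << "result: " << t << "\n";
}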
- auto outlined_func = - builder.create(loc, name, type, ArrayRef{}); - - outlined_func.getBody().takeBody(region); + auto outlined_func = builder.create(loc, name, type); Region& func_region = outlined_func.getBody(); + func_region.takeBody(region); Block& first_block = func_region.front(); // Replace all external uses with function arguments. @@ -115,27 +136,24 @@ void ExtractSingleBlockRegion(Region& region, FunctionType type, StringRef name, replaceAllUsesInRegionWith(it.value(), arg, func_region); } - // Replace the existing terminator with a return. - Operation* terminator = outlined_func.getBody().front().getTerminator(); - builder.setInsertionPoint(terminator); + // Function return values are all the terminator operands + pass through + // extern values (if enabled). + auto return_values = llvm::to_vector<4>(terminator->getOperands()); + if (extern_values_passthrough) + return_values.insert(return_values.end(), + first_block.args_begin() + num_region_arguments, + first_block.args_end()); - SmallVector return_values; - return_values.reserve(terminator->getNumOperands()); - for (auto it : llvm::enumerate(type.getResults())) { - Value ret_val = terminator->getOperand(it.index()); - // Add a cast operation if types do not match. - if (ret_val.getType() != it.value()) { - ret_val = - builder.create(terminator->getLoc(), it.value(), ret_val); - } - return_values.push_back(ret_val); - } + // Replace the existing terminator with a return. + terminator = first_block.getTerminator(); + builder.setInsertionPoint(terminator); builder.create(terminator->getLoc(), return_values); terminator->erase(); + outlined_func.setVisibility(FuncOp::Visibility::Private); // Add the outlined function to the worklist in case its body has - // IfRegion ops that need to converted. + // IfRegion or WhileRegion ops that need to be converted. worklist.push_back(outlined_func); } @@ -170,17 +188,29 @@ llvm::Optional IsSingleCallRegion(Region& region) { return call; } -// Returns whether the arguments of the given call are same as the given list of -// arguments (after looking through cast ops). -bool MatchCallArgs(CallOp call, llvm::SmallVectorImpl& args) { - if (call.getNumOperands() != args.size()) return false; +using MatcherFn = function_ref; - for (auto it : llvm::enumerate(args)) { - Value arg = call.getOperand(it.index()); - if (auto cast = dyn_cast_or_null(arg.getDefiningOp())) - arg = cast.getOperand(); +// Returns whether the arguments of the given two calls match (after looking +// through cast ops). `matcher` is the predicate used to check if two arguments +// match. +bool MatchCallArgs(CallOp first, CallOp second, MatcherFn matcher) { + if (first.getNumOperands() != second.getNumOperands()) return false; - if (arg != it.value()) return false; + Region& first_region = *first.getParentRegion(); + Region& second_region = *second.getParentRegion(); + + for (auto it : llvm::zip(first.getArgOperands(), second.getArgOperands())) { + // Get the defining Op, skipping over casts. + auto get_defining_op = [](Value value) { + while (llvm::isa_and_nonnull(value.getDefiningOp())) + value = cast(value.getDefiningOp()).getOperand(); + return value; + }; + Value first_arg = get_defining_op(std::get<0>(it)); + Value second_arg = get_defining_op(std::get<1>(it)); + + if (!matcher(first_arg, first_region, second_arg, second_region)) + return false; } return true; } @@ -193,11 +223,10 @@ struct TrivialTransformInfo { bool can_transform = false; // List of callee names (one for each region).
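// Aside: MatchCallArgs above walks the two call sites' operands in lockstep,
// looks through intervening casts, and defers the actual equivalence test to a
// caller-supplied matcher. A rough standalone analogue, assuming a hypothetical
// value model (a value is an integer id, casts are recorded in a side map); it
// is a sketch of the idea, not the MLIR implementation.
#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

using ValueId = int;
using CastMap = std::unordered_map<ValueId, ValueId>;  // cast result -> operand
using Matcher = std::function<bool(ValueId, ValueId)>;

// Follows chains of casts back to the originally defined value.
ValueId LookThroughCasts(ValueId v, const CastMap& cast_of) {
  auto it = cast_of.find(v);
  while (it != cast_of.end()) {
    v = it->second;
    it = cast_of.find(v);
  }
  return v;
}

// Analogue of MatchCallArgs: same arity, and every operand pair must satisfy
// the matcher after looking through casts.
bool MatchCallArgs(const std::vector<ValueId>& first,
                   const std::vector<ValueId>& second, const CastMap& cast_of,
                   const Matcher& matcher) {
  if (first.size() != second.size()) return false;
  for (std::size_t i = 0; i < first.size(); ++i) {
    if (!matcher(LookThroughCasts(first[i], cast_of),
                 LookThroughCasts(second[i], cast_of)))
      return false;
  }
  return true;
}

int main() {
  CastMap cast_of{{10, 1}};  // value 10 is a cast of value 1
  auto same_value = [](ValueId a, ValueId b) { return a == b; };
  std::cout << std::boolalpha
            << MatchCallArgs({10, 2}, {1, 2}, cast_of, same_value) << "\n";
}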
- llvm::SmallVector callee_names; + llvm::SmallVector callee_names; - // List of arguments used in these call (each call uses the same arguments - // potentially through casts). - llvm::SmallVector call_args; + // Constructor will analyze the 2 regions. + TrivialTransformInfo(Region& first, Region& second, MatcherFn matcher); }; // Analyzes the given set of regions (attached to the same parent op) to check @@ -206,88 +235,62 @@ struct TrivialTransformInfo { // regions are single call regions and the all the calls have the same // arguments. // -// If this trivial transformation is possible, return the relevant information -// needed for the transformation (in `TrivialTransformInfo`), else indicate that -// a trivial transformation is not possible by setting `can_transform` false. -TrivialTransformInfo AnalyzeForTrivialTransform(ArrayRef regions) { - const TrivialTransformInfo cannot_transform; +// If such a trivial transformation is possible, stash the relevant information +// needed for the transformation, else indicate that a trivial transformation is +// not possible by setting `can_transform` to false. +TrivialTransformInfo::TrivialTransformInfo(Region& first, Region& second, + MatcherFn matcher) { + auto call0 = IsSingleCallRegion(first); + auto call1 = IsSingleCallRegion(second); + if (!call0 || !call1) return; - if (regions.empty()) return cannot_transform; + if (!MatchCallArgs(call0.getValue(), call1.getValue(), matcher)) return; - llvm::SmallVector calls; - calls.reserve(regions.size()); - - // Verify each region is a single call and collect these calls. - for (Region* region : regions) { - auto call = IsSingleCallRegion(*region); - if (!call.hasValue()) return cannot_transform; - calls.push_back(call.getValue()); - } - - llvm::SmallVector callees; - callees.reserve(regions.size()); - - CallOp call0 = calls[0]; - int num_args = call0.getNumOperands(); - - // Collect arguments of the first call. - llvm::SmallVector call0_args; - call0_args.reserve(num_args); - for (Value arg : call0.getArgOperands()) { - if (auto cast = dyn_cast_or_null(arg.getDefiningOp())) - arg = cast.getOperand(); - call0_args.push_back(arg); - } - - // Match arguments of rest of the calls with those of the first call. - for (auto call : calls) { - if (call != call0 && !MatchCallArgs(call, call0_args)) - return cannot_transform; - callees.push_back(call.getCallee()); - } - - return {true, callees, call0_args}; + can_transform = true; + callee_names = {call0.getValue().getCallee(), call1.getValue().getCallee()}; } // Transform IfRegionOp to IfOp. LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { - const TrivialTransformInfo tti = AnalyzeForTrivialTransform( - {&if_region.then_branch(), &if_region.else_branch()}); + llvm::SmallVector extern_values; + + // For IfOp, arguments of calls in the then and else regions match if they + // are the same value. + auto if_matcher = [&](Value first, Region&, Value second, Region&) { + if (first != second) return false; + + // collect the call arguments post lookup through cast Op's + extern_values.push_back(first); + return true; + }; + + const TrivialTransformInfo tti(if_region.then_branch(), + if_region.else_branch(), if_matcher); std::string then_name, else_name; - llvm::SmallVector extern_values; if (tti.can_transform) { // We can transform to functional form trivially without outlining. 
then_name = tti.callee_names[0].str(); else_name = tti.callee_names[1].str(); - extern_values = tti.call_args; } else { // Collect external values that are used within the else and then bodies. - extern_values = CollectExternValues( - {&if_region.then_branch(), &if_region.else_branch()}); + extern_values = + CollectExternValues(if_region.then_branch(), if_region.else_branch()); // These external values need to be added as inputs to the generated If. The // order is determined by the order of these values the `extern_vales`. - // Build the type for the outlined function. - llvm::SmallVector input_types; - input_types.reserve(extern_values.size()); - for (auto input : extern_values) input_types.push_back(input.getType()); - - FunctionType func_type = FunctionType::get( - input_types, if_region.getResultTypes(), if_region.getContext()); - // Create 2 new functions with the input signature matching this order, // and outline the `then` and `else` regions by moving the bodies of these // regions into these functions. Replace tf.yield with a regular return. then_name = GetName(if_region, "_then"); - ExtractSingleBlockRegion(if_region.then_branch(), func_type, then_name, - extern_values, worklist); + ExtractSingleBlockRegion(if_region.then_branch(), then_name, extern_values, + worklist, /*extern_values_passthrough=*/false); else_name = GetName(if_region, "_else"); - ExtractSingleBlockRegion(if_region.else_branch(), func_type, else_name, - extern_values, worklist); + ExtractSingleBlockRegion(if_region.else_branch(), else_name, extern_values, + worklist, /*extern_values_passthrough=*/false); } // Once we have the `then` and `else` functions ready (either outlined or @@ -297,24 +300,111 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { auto if_op = builder.create( if_region.getLoc(), if_region.getResultTypes(), if_region.cond(), extern_values, then_name, else_name, if_region.is_stateless()); + CopyUnderscoredAttributes(if_region, if_op); if_region.replaceAllUsesWith(if_op.getResults()); if_region.erase(); return success(); } +// Transform WhileRegion to WhileOp. +LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( + WhileRegionOp while_region) { + // For While, the arguments of the calls in the body and cond regions match + // if they are region arguments with the same region argument numbers. If the + // two calls have the same value (an extern value) used as an argument, we + // cannot do a trivial transformation because post transform, we will need to + // pass this extern value as an argument to the function, so we cannot use the + // existing function as is. + auto while_matcher = [](Value first, Region& first_region, Value second, + Region& second_region) { + if (!first.isa() || !second.isa()) + return false; + BlockArgument first_block_arg = first.cast(); + BlockArgument second_block_arg = second.cast(); + + // Two block arguments will match if they have the same argument number, and + // are block arguments of the corresponding containing regions. + return first_block_arg.getArgNumber() == second_block_arg.getArgNumber() && + first_block_arg.getParentBlock() == &first_region.front() && + second_block_arg.getParentBlock() == &second_region.front(); + }; + + const TrivialTransformInfo tti(while_region.cond(), while_region.body(), + while_matcher); + + // All existing inputs to the while region are inputs to the functional while.
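// Aside: the while matcher above only accepts pairs of block arguments that
// carry the same argument number and live in the entry blocks of their
// respective regions; anything else (in particular a shared captured value)
// defeats trivial reuse of the existing cond/body functions. A simplified
// standalone analogue with a hypothetical ArgRef struct; it mirrors only the
// predicate, not MLIR's BlockArgument.
#include <iostream>
#include <optional>

// Hypothetical stand-in for a call operand: either a block argument of some
// region (region id + argument number) or any other value.
struct ArgRef {
  std::optional<int> owning_region;  // set only for block arguments
  int arg_number = -1;
};

// Analogue of while_matcher: both operands must be block arguments, with the
// same argument number, and each must belong to its own (cond/body) region.
bool WhileArgsMatch(const ArgRef& first, int first_region, const ArgRef& second,
                    int second_region) {
  if (!first.owning_region || !second.owning_region) return false;
  return first.arg_number == second.arg_number &&
         *first.owning_region == first_region &&
         *second.owning_region == second_region;
}

int main() {
  constexpr int kCondRegion = 0, kBodyRegion = 1;
  ArgRef cond_arg{kCondRegion, 0}, body_arg{kBodyRegion, 0};
  ArgRef captured{};  // not a block argument: trivial transform not possible
  std::cout << std::boolalpha
            << WhileArgsMatch(cond_arg, kCondRegion, body_arg, kBodyRegion)
            << " "
            << WhileArgsMatch(captured, kCondRegion, captured, kBodyRegion)
            << "\n";  // true false
}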
+ auto new_inputs = llvm::to_vector<4>(while_region.getOperands()); + + // All existing results will also be generated by the functional while. + auto new_result_types = llvm::to_vector<4>(while_region.getResultTypes()); + + std::string cond_name, body_name; + if (tti.can_transform) { + // We can transform to functional form trivially without outlining. + cond_name = tti.callee_names[0].str(); + body_name = tti.callee_names[1].str(); + } else { + // The WhileRegion regions can refer to either arguments of the region, or + // external values implicitly captured by the region. When converting to + // functional form, all such external values need to become function + // arguments of the outlined functions, and become pass through values in + // the outlined body function. So when outlining the while body, in addition + // to the region arguments, all these external references need to be added + // as function arguments. + llvm::SmallVector extern_values = + CollectExternValues(while_region.cond(), while_region.body()); + + // Outline the `cond` and `body` regions by moving the bodies of these + // regions into new functions. Replace tf.yield with a regular return. + cond_name = GetName(while_region, "_cond"); + ExtractSingleBlockRegion(while_region.cond(), cond_name, extern_values, + worklist, /*extern_values_passthrough=*/false); + + body_name = GetName(while_region, "_body"); + ExtractSingleBlockRegion(while_region.body(), body_name, extern_values, + worklist, /*extern_values_passthrough=*/true); + + // All extern values become additional inputs and additional output types + // for the functional while. + new_inputs.append(extern_values.begin(), extern_values.end()); + for (auto ext : extern_values) new_result_types.push_back(ext.getType()); + } + + // Once we have the `cond` and `body` functions ready (either outlined or + // existing ones), replace the region based op with a functional op. + OpBuilder builder(while_region); + auto while_op = builder.create( + while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, + while_region.parallel_iterations(), while_region.is_stateless()); + CopyUnderscoredAttributes(while_region, while_op); + + // Redirect old results to new results. + for (auto it : llvm::zip( + while_region.getResults(), + while_op.getResults().take_front(while_region.getNumResults()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + + while_region.erase(); + return success(); +} + void RegionControlFlowToFunctional::runOnOperation() { ModuleOp module = getOperation(); // Seed worklist with all functions in the module. 
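// Aside: the driver below is a classic worklist fixpoint. It seeds the worklist
// with every function in the module, converts any region-based control flow it
// finds, and, because outlining creates new functions whose bodies may contain
// further region ops, pushes those new functions back onto the worklist. The
// standalone sketch below mirrors that shape with hypothetical Function and
// ConvertRegionOp helpers; error handling in the real pass is elided here.
#include <iostream>
#include <string>
#include <vector>

// Hypothetical model: a "function" owns a count of region ops still to convert;
// converting a region op outlines new (already converted) functions.
struct Function {
  std::string name;
  int region_ops = 0;
};

// Pretend each conversion outlines exactly two functions (then/else, cond/body).
std::vector<Function> ConvertRegionOp(const Function& parent, int index) {
  return {{parent.name + "_outlined_a" + std::to_string(index), 0},
          {parent.name + "_outlined_b" + std::to_string(index), 0}};
}

int main() {
  // Seed the worklist with the module's functions, then drain it; newly
  // outlined functions are appended so nested region ops also get converted.
  std::vector<Function> worklist = {{"main", 2}};
  int converted = 0;
  while (!worklist.empty()) {
    Function func = worklist.back();
    worklist.pop_back();
    for (int i = 0; i < func.region_ops; ++i) {
      auto outlined = ConvertRegionOp(func, i);
      worklist.insert(worklist.end(), outlined.begin(), outlined.end());
      ++converted;
    }
  }
  std::cout << "converted " << converted << " region ops\n";
}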
worklist = llvm::to_vector<4>(module.getOps()); - while (!worklist.empty()) { FuncOp function = worklist.pop_back_val(); auto result = function.walk([&](Operation* op) { - if (IfRegionOp if_region = llvm::dyn_cast(op)) { + if (auto if_region = llvm::dyn_cast(op)) { if (failed(ConvertIfOp(if_region))) { - if_region.emitOpError() << " failed to convert to functional form"; + op->emitOpError() << "failed to convert to functional form"; + return WalkResult::interrupt(); + } + } else if (auto while_region = llvm::dyn_cast(op)) { + if (failed(ConvertWhileOp(while_region))) { + op->emitOpError() << "failed to convert to functional form"; return WalkResult::interrupt(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index b16868311f0..ef75f90d5c1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -32,12 +33,14 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/core/platform/logging.h" namespace mlir { @@ -45,10 +48,11 @@ namespace TFDevice { namespace { constexpr char kDeviceAttr[] = "device"; constexpr char kReplicaIdAttr[] = "_xla_replica_id"; +constexpr char kDeviceOrdinalAttr[] = "device_ordinal"; struct ReplicateToIslandPass - : public PassWrapper { - void runOnFunction() override; + : public PassWrapper> { + void runOnOperation() override; }; // Returns whether op requires `_xla_replica_id` attribute. @@ -57,29 +61,207 @@ bool RequiresReplicaIDAttribute(Operation* op) { TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op); } -// Adds integer attribute that represents replica id for replicated ops that -// require replica id attribute. -void AddReplicaIdToOpsInReplicatedRegion(OpBuilder* builder, Region* region, - const int replica_id) { - region->walk([&](Operation* replicated_op) { - if (RequiresReplicaIDAttribute(replicated_op)) - replicated_op->setAttr(kReplicaIdAttr, - builder->getI32IntegerAttr(replica_id)); +bool RequiresDeviceOrdinalAttribute(Operation* op) { + return llvm::isa(op) || + llvm::isa(op); +} + +// Checks if a region contains ops that are replica variant. 
+bool HasReplicaVariantOps(Region& region, + const llvm::Optional& devices) { + auto result = region.walk([&](Operation* op) { + if (RequiresReplicaIDAttribute(op) || + (devices.hasValue() && RequiresDeviceOrdinalAttribute(op))) + return WalkResult::interrupt(); + + if (auto launch = dyn_cast(op)) + if (devices.hasValue() && devices.getValue().get(launch.device())) + return WalkResult::interrupt(); + + return WalkResult::advance(); }); + return result.wasInterrupted(); +} + +// Collects all functions reachable from a region, including transitive ones. +llvm::SmallPtrSet GetReachableFunctionsFromRegion(ModuleOp module, + Region& region) { + llvm::SmallPtrSet visited_functions; + + SymbolTable symbol_table(module); + auto symbol_uses = symbol_table.getSymbolUses(®ion); + if (!symbol_uses) return {}; + + for (auto& use : *symbol_uses) + if (auto func = + symbol_table.lookup(use.getSymbolRef().getRootReference())) + visited_functions.insert(func); + + llvm::SmallVector functions_to_visit(visited_functions.begin(), + visited_functions.end()); + while (!functions_to_visit.empty()) { + llvm::SmallVector new_functions_to_visit; + + for (FuncOp function_to_visit : functions_to_visit) { + auto func_symbol_uses = + symbol_table.getSymbolUses(function_to_visit.getCallableRegion()); + if (!func_symbol_uses) continue; + + for (auto& use : *func_symbol_uses) + if (auto func = symbol_table.lookup( + use.getSymbolRef().getRootReference())) + if (visited_functions.insert(func).second) + new_functions_to_visit.push_back(func); + } + + functions_to_visit.swap(new_functions_to_visit); + } + + return visited_functions; +} + +// Collects all functions and transitive functions reachable from region that +// contain replicate variant ops. +llvm::SmallDenseMap GetReachableFunctionsToClone( + ModuleOp module, Region& region, + const llvm::Optional& devices) { + llvm::SmallPtrSet reachable_functions = + GetReachableFunctionsFromRegion(module, region); + + llvm::SmallDenseMap functions_to_clone; + llvm::SmallVector functions_to_visit; + for (FuncOp func : reachable_functions) { + if (!func.getCallableRegion()) continue; + if (HasReplicaVariantOps(*func.getCallableRegion(), devices)) { + functions_to_clone.insert({func.getName(), func}); + functions_to_visit.push_back(func); + } + } + + while (!functions_to_visit.empty()) { + llvm::SmallVector new_functions_to_visit; + + for (FuncOp func_to_visit : functions_to_visit) { + auto func_uses = func_to_visit.getSymbolUses(module); + if (!func_uses) continue; + for (auto use : *func_uses) { + auto parent_func = use.getUser()->getParentOfType(); + if (!parent_func || !reachable_functions.contains(parent_func) || + !functions_to_clone.insert({parent_func.getName(), parent_func}) + .second) + continue; + new_functions_to_visit.push_back(parent_func); + } + } + + functions_to_visit.swap(new_functions_to_visit); + } + + return functions_to_clone; +} + +struct FuncOldNameAndClone { + StringRef old_name; + FuncOp clone; +}; + +// Replaces all symbol uses with cloned functions, for `region` and across the +// cloned functions themselves. 
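// Aside: GetReachableFunctionsFromRegion above computes a transitive closure
// over symbol uses: functions directly referenced from the region seed a
// worklist, and each visited function contributes its own callees until nothing
// new is discovered. A plain-C++ sketch of the same closure over a hypothetical
// call graph keyed by function name; it illustrates the traversal only.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using CallGraph = std::map<std::string, std::vector<std::string>>;

// Starts from the functions directly used by a region and keeps following
// callees until the visited set stops growing.
std::set<std::string> ReachableFunctions(
    const CallGraph& graph, const std::vector<std::string>& roots) {
  std::set<std::string> visited(roots.begin(), roots.end());
  std::vector<std::string> worklist(roots.begin(), roots.end());
  while (!worklist.empty()) {
    std::string current = worklist.back();
    worklist.pop_back();
    auto it = graph.find(current);
    if (it == graph.end()) continue;
    for (const std::string& callee : it->second)
      if (visited.insert(callee).second) worklist.push_back(callee);
  }
  return visited;
}

int main() {
  CallGraph graph{{"launch_fn", {"helper"}}, {"helper", {"leaf"}}};
  for (const auto& name : ReachableFunctions(graph, {"launch_fn"}))
    std::cout << name << "\n";  // helper, launch_fn, leaf
}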
+LogicalResult UpdateSymbolUsesWithClones( + SymbolTable& symbol_table, ModuleOp module, Region& region, + llvm::MutableArrayRef cloned_functions) { + llvm::SmallVector, 4> old_to_new_names; + old_to_new_names.reserve(cloned_functions.size()); + for (auto& cloned_function : cloned_functions) + old_to_new_names.push_back( + {cloned_function.old_name, cloned_function.clone.getName()}); + + for (const auto& old_to_new_name : old_to_new_names) { + if (failed(symbol_table.replaceAllSymbolUses( + old_to_new_name.first, old_to_new_name.second, ®ion))) + return failure(); + + for (auto& cloned_function : cloned_functions) + if (failed(symbol_table.replaceAllSymbolUses( + old_to_new_name.first, old_to_new_name.second, + cloned_function.clone.getCallableRegion()))) + return failure(); + } + return success(); +} + +// Collects TPU device ordinal for outside compilation communication ops. This +// currently assumes outside compilation only uses `TPU_REPLICATED_CORE_0` +// aliased device for the device computation. +llvm::Optional GetDeviceOrdinal( + const llvm::Optional& devices, Location loc, + unsigned replica_id) { + int64_t device_ordinal = 0; + if (devices.hasValue()) { + if (auto tpu_replica_0 = devices.getValue().get("TPU_REPLICATED_CORE_0")) { + llvm::StringRef tpu_device = tpu_replica_0.cast()[replica_id] + .cast() + .getValue(); + if (succeeded(tensorflow::GetDeviceOrdinalFromDeviceString( + loc, tpu_device, &device_ordinal))) { + return llvm::Optional(device_ordinal); + } + } + } + return llvm::None; +} + +// Updates replica variant ops in a region based on replica `replica_id`. +// TODO(b/157624749): Replace this with better abstraction to differentiate ops +// for different replicas. Some ops, such as XlaHostCompute op or TPU Embedding +// ops, require replica id to be added as an op attribute to be used during +// execution. Handle such ops separately and add an integer attribute that +// represents replica id. +LogicalResult UpdateRegionReplicateVariantOps( + OpBuilder& builder, Location loc, Region& region, int replica_id, + llvm::MutableArrayRef cloned_functions, + const llvm::Optional& devices) { + llvm::Optional device_ordinal = + GetDeviceOrdinal(devices, loc, replica_id); + + auto update_replicate_variant_ops = [&](Operation* op) { + // Add replica id. + if (RequiresReplicaIDAttribute(op)) + op->setAttr(kReplicaIdAttr, builder.getI32IntegerAttr(replica_id)); + + if (!devices.hasValue()) return; + + // Map aliased devices to explicit devices based on replica. + if (auto launch = dyn_cast(op)) + if (auto device_by_replica = devices.getValue().get(launch.device())) + launch.setAttr( + kDeviceAttr, + device_by_replica.cast()[replica_id].cast()); + + // Add device ordinal. + if (device_ordinal && RequiresDeviceOrdinalAttribute(op)) + op->setAttr(kDeviceOrdinalAttr, + builder.getI64IntegerAttr(*device_ordinal)); + }; + + region.walk(update_replicate_variant_ops); + for (auto& cloned_function : cloned_functions) + cloned_function.clone.getCallableRegion()->walk( + update_replicate_variant_ops); + + return success(); } // Creates islands per replica from `tf_device.replicate` region. If for a // `tf_device.launch` op the device is an aliased device of the // `tf_device.replicate`, the device will be remapped to an explicit device // for the associated replica island. 
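// Aside: GetDeviceOrdinal and UpdateRegionReplicateVariantOps above specialize
// each replica: the aliased device list is indexed by replica id, and for ops
// that need it a device ordinal is derived from the chosen device string. The
// sketch below mirrors that selection with plain maps and assumes, purely for
// illustration, that the ordinal is the trailing number of a device string such
// as "/job:worker/task:0/device:TPU:1"; the real pass delegates this parsing to
// GetDeviceOrdinalFromDeviceString.
#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <vector>

using DeviceMap = std::map<std::string, std::vector<std::string>>;

// Picks the concrete device for `replica_id` from an aliased device list such
// as devices["TPU_REPLICATED_CORE_0"].
std::optional<std::string> DeviceForReplica(const DeviceMap& devices,
                                            const std::string& alias,
                                            int replica_id) {
  auto it = devices.find(alias);
  if (it == devices.end() || replica_id >= static_cast<int>(it->second.size()))
    return std::nullopt;
  return it->second[replica_id];
}

// Illustrative only: treat the text after the last ':' as the device ordinal.
std::optional<int> DeviceOrdinalFromString(const std::string& device) {
  auto pos = device.rfind(':');
  if (pos == std::string::npos || pos + 1 >= device.size()) return std::nullopt;
  return std::stoi(device.substr(pos + 1));
}

int main() {
  DeviceMap devices{{"TPU_REPLICATED_CORE_0",
                     {"/job:worker/task:0/device:TPU:0",
                      "/job:worker/task:0/device:TPU:1"}}};
  auto device = DeviceForReplica(devices, "TPU_REPLICATED_CORE_0", 1);
  if (device) {
    auto ordinal = DeviceOrdinalFromString(*device);
    std::cout << *device << " -> ordinal " << (ordinal ? *ordinal : -1) << "\n";
  }
}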
-llvm::SmallVector ExpandReplicateIntoReplicas( - const Dialect* tf_dialect, OpBuilder* builder, +LogicalResult ExpandReplicateIntoReplicas( + const Dialect* tf_dialect, OpBuilder& builder, ModuleOp module, tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op, - int num_replicas) { - auto devices = replicate_op.devices(); - const bool has_devices = devices.hasValue(); - llvm::SmallVector replicas; + int num_replicas, llvm::SmallVectorImpl& replicas) { replicas.reserve(num_replicas); + auto devices = replicate_op.devices(); // Collect result types and operands. Operation& terminator = replicate_op.GetBody().back(); @@ -88,16 +270,30 @@ llvm::SmallVector ExpandReplicateIntoReplicas( llvm::SmallVector replica_inputs(island_op.controlInputs()); // Replace replicate terminator with YieldOp. - builder->setInsertionPoint(&terminator); - builder->create(terminator.getLoc(), - terminator.getOperands()); + builder.setInsertionPoint(&terminator); + builder.create(terminator.getLoc(), + terminator.getOperands()); terminator.erase(); - builder->setInsertionPoint(island_op); + auto funcs_to_clone = + GetReachableFunctionsToClone(module, replicate_op.body(), devices); + SymbolTable symbol_table(module); + + builder.setInsertionPoint(island_op); BlockAndValueMapping mapping; for (int i : llvm::seq(0, num_replicas)) { + // Clone reachable functions with replica variant ops. + llvm::SmallVector cloned_functions; + cloned_functions.reserve(funcs_to_clone.size()); + for (auto& func_to_clone : funcs_to_clone) { + auto cloned_function = func_to_clone.getSecond().clone(); + symbol_table.insert(cloned_function, module.end()); + cloned_functions.push_back( + {func_to_clone.getSecond().getName(), cloned_function}); + } + // Create new island for replica. - auto replica = builder->create( + auto replica = builder.create( island_op.getLoc(), output_types, control_type, replica_inputs); // Map block arg to replica arg. @@ -109,28 +305,19 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // Copy over replicate region into replica island. replicate_op.body().cloneInto(&replica.body(), mapping); - // TODO(b/157624749): Replace this with better abstraction to - // differentiate ops for different replicas. - // Some ops, such as XlaHostCompute op or TPU Embedding ops, require - // replica id to be added as an op attribute to be used during - // execution. Handle such ops separately and add an integer attribute - // that represents replica id. - AddReplicaIdToOpsInReplicatedRegion(builder, &replica.body(), i); + if (failed(UpdateSymbolUsesWithClones(symbol_table, module, replica.body(), + cloned_functions))) + return failure(); - // Map aliased devices to explicit devices based on replica. 
- if (has_devices) { - replica.walk([&](tf_device::LaunchOp launch) { - if (auto device_by_replica = devices.getValue().get(launch.device())) - launch.setAttr( - kDeviceAttr, - device_by_replica.cast()[i].cast()); - }); - } + if (failed(UpdateRegionReplicateVariantOps( + builder, replicate_op.getLoc(), replica.body(), + /*replica_id=*/i, cloned_functions, devices))) + return failure(); replicas.push_back(replica); } - return replicas; + return success(); } // Creates islands per replica from `tf_device.replicate` region and remap @@ -183,17 +370,19 @@ llvm::SmallVector ExpandReplicateIntoReplicas( // }) {device = "/DEVICE:3"} : () -> tensor // tf_executor.yield %a1, %b1 : tensor, tensor // } -void CreateIslandsFromReplicate(const Dialect* tf_dialect, - tf_executor::GraphOp graph_op, - tf_executor::IslandOp island_op, - tf_device::ReplicateOp replicate_op) { +LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, + ModuleOp module, + tf_executor::GraphOp graph_op, + tf_executor::IslandOp island_op, + tf_device::ReplicateOp replicate_op) { OpBuilder builder(island_op); const int num_replicas = replicate_op.n().getLimitedValue(); // Create islands per replica. - llvm::SmallVector replicas = - ExpandReplicateIntoReplicas(tf_dialect, &builder, island_op, replicate_op, - num_replicas); + llvm::SmallVector replicas; + if (failed(ExpandReplicateIntoReplicas(tf_dialect, builder, module, island_op, + replicate_op, num_replicas, replicas))) + return failure(); // Collect all replica results. llvm::SmallVector replicas_outputs(replicate_op.getNumResults(), @@ -244,36 +433,41 @@ void CreateIslandsFromReplicate(const Dialect* tf_dialect, } island_op.erase(); + return success(); } -// Finds islands with a single `tf_device.replicate` and create individual -// islands per replica of the replicate. -void LowerSingleIslandReplicateToIslands(const Dialect* tf_dialect, - tf_executor::GraphOp graph_op, - tf_executor::IslandOp island_op) { - if (!island_op.WrapsSingleOp()) return; - - if (auto replicate_op = - llvm::dyn_cast(&island_op.GetBody().front())) - CreateIslandsFromReplicate(tf_dialect, graph_op, island_op, replicate_op); -} - -void ReplicateToIslandPass::runOnFunction() { +void ReplicateToIslandPass::runOnOperation() { + auto module = getOperation(); const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); if (!tf_dialect) { - signalPassFailure(); - getFunction().emitError() << "'tf' dialect is not registered"; + module.emitError() << "'tf' dialect is not registered"; + return signalPassFailure(); } - getFunction().walk([&](tf_executor::GraphOp graph_op) { - for (auto island_op : - llvm::make_early_inc_range(graph_op.getOps())) - LowerSingleIslandReplicateToIslands(tf_dialect, graph_op, island_op); + // Find islands with a single `tf_device.replicate` and create individual + // islands per replica of the replicate. 
+ llvm::SmallVector replicate_op_islands; + module.walk([&](tf_executor::GraphOp graph_op) { + for (auto island_op : graph_op.getOps()) { + if (!island_op.WrapsSingleOp()) continue; + + if (isa(&island_op.GetBody().front())) + replicate_op_islands.push_back(island_op); + } }); + + for (tf_executor::IslandOp island_op : replicate_op_islands) { + auto graph_op = island_op.getParentOfType(); + auto replicate_op = + cast(island_op.GetBody().front()); + if (failed(CreateIslandsFromReplicate(tf_dialect, module, graph_op, + island_op, replicate_op))) + return signalPassFailure(); + } } } // anonymous namespace -std::unique_ptr> CreateReplicateToIslandPass() { +std::unique_ptr> CreateReplicateToIslandPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index 21d74d81b20..7e8e9ee30c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -36,7 +36,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -61,7 +61,9 @@ struct ResourceDeviceInference // A class that records each resource's device assignment in a function. class PerFunctionResult { public: - explicit PerFunctionResult(FuncOp func_op) : alias_analysis_(func_op) {} + explicit PerFunctionResult( + FuncOp func_op, const TF::ResourceAliasAnalysis::Info& alias_analysis) + : alias_analysis_(alias_analysis) {} // Returns the recorded device assignment for a resource, if any. llvm::Optional DeviceForResource( @@ -105,7 +107,7 @@ class PerFunctionResult { private: llvm::SmallDenseMap resource_id_to_device_; - TF::ResourceAliasAnalysis alias_analysis_; + const TF::ResourceAliasAnalysis::Info& alias_analysis_; }; // Tries to record device assignment for a resource. @@ -193,46 +195,50 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, void ResourceDeviceInference::runOnOperation() { auto module = getOperation(); + const auto& resource_alias_analysis = + getAnalysis(); + llvm::SmallDenseMap per_function_results; llvm::SetVector worklist; module.walk([&](FuncOp func_op) { worklist.insert(func_op); - per_function_results.try_emplace(func_op, func_op); + per_function_results.try_emplace( + func_op, func_op, resource_alias_analysis.GetAnalysisForFunc(func_op)); }); // Helper that propagates an op's recorded operand device assignments to its // called function's arguments. 
auto propagate_operands_to_callee_arguments = [&](Operation* caller, Operation::operand_range caller_operands, - llvm::StringRef called_func_name, - const PerFunctionResult& caller_res) { - auto callee = - llvm::dyn_cast(module.lookupSymbol(called_func_name)); - assert(callee); - auto& callee_res = per_function_results.find(callee)->getSecond(); - bool callee_needs_recompute = false; - for (auto operand_and_argument : - llvm::zip(caller_operands, callee.getArguments())) { - if (!mlir::getElementTypeOrSelf( - std::get<0>(operand_and_argument).getType()) - .isa()) { - continue; + ArrayRef callees, const PerFunctionResult& caller_res) { + for (FuncOp callee : callees) { + assert(callee); + auto& callee_res = per_function_results.find(callee)->getSecond(); + bool callee_needs_recompute = false; + for (auto operand_and_argument : + llvm::zip(caller_operands, callee.getArguments())) { + if (!mlir::getElementTypeOrSelf( + std::get<0>(operand_and_argument).getType()) + .isa()) { + continue; + } + auto device = + caller_res.DeviceForResource(std::get<0>(operand_and_argument)); + if (!device) continue; + if (failed(AddResourceDeviceAndEmitError( + std::get<1>(operand_and_argument), *device, caller, + &callee_res, &callee_needs_recompute))) { + return failure(); + } } - auto device = - caller_res.DeviceForResource(std::get<0>(operand_and_argument)); - if (!device) continue; - if (failed(AddResourceDeviceAndEmitError( - std::get<1>(operand_and_argument), *device, caller, - &callee_res, &callee_needs_recompute))) { - return failure(); + // If the callee recording is modified, make sure that it will be + // reprocessed. + if (callee_needs_recompute) { + worklist.insert(callee); } } - // If the callee recording is modified, make sure that it will be - // reprocessed. 
- if (callee_needs_recompute) { - worklist.insert(callee); - } return success(); }; + while (!worklist.empty()) { auto func_op = worklist.back(); worklist.pop_back(); @@ -245,18 +251,14 @@ void ResourceDeviceInference::runOnOperation() { auto walk_res = func_op.walk([&](Operation* op) { if (auto while_op = llvm::dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( - while_op, while_op.getOperands(), while_op.body(), func_res)) || - failed(propagate_operands_to_callee_arguments( - while_op, while_op.getOperands(), while_op.cond(), func_res))) { + while_op, while_op.getOperands(), + {while_op.body_func(), while_op.cond_func()}, func_res))) return WalkResult::interrupt(); - } } else if (auto if_op = llvm::dyn_cast(op)) { if (failed(propagate_operands_to_callee_arguments( - if_op, if_op.input(), if_op.then_branch(), func_res)) || - failed(propagate_operands_to_callee_arguments( - if_op, if_op.input(), if_op.else_branch(), func_res))) { + if_op, if_op.input(), {if_op.then_func(), if_op.else_func()}, + func_res))) return WalkResult::interrupt(); - } } return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 6a67f0bea0a..702455d156d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -558,15 +558,13 @@ void AddLoadsStoresOutsideControlFlowOp( auto operand = caller->getOperand(index); builder.setInsertionPoint(caller); new_operands[index] = builder.create( - caller->getLoc(), ArrayRef{new_type}, ArrayRef{operand}, - ArrayRef{}); + caller->getLoc(), ArrayRef{new_type}, ArrayRef{operand}); caller->setOperand(index, new_operands[index]); if (updated_index < 0) continue; builder.setInsertionPointAfter(caller); builder.create( caller->getLoc(), ArrayRef{}, - ArrayRef{operand, caller->getResult(updated_index)}, - ArrayRef{}); + ArrayRef{operand, caller->getResult(updated_index)}); } } @@ -629,8 +627,6 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { }); // Recreate the while op. OpBuilder builder(while_op); - auto new_output_shapes = FilterRange>( - while_op.output_shapes().getValue(), resource_arg_uses); // Now use the filtered original operands, which will be replaced by // AddLoadsStoresOutsideControlFlowOp(). auto new_while = builder.create( @@ -638,8 +634,7 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { FilterRange(while_op.getOperands(), resource_arg_uses), while_op.getAttrs()); - // Prepare for AddLoadsStoresOutsideControlFlowOp() and update - // new_output_shapes. + // Prepare for AddLoadsStoresOutsideControlFlowOp(). llvm::SmallDenseMap> arg_data_type_and_updated_output_index; for (const auto& entry : remaining_resource_data_types) { @@ -649,16 +644,11 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { : entry.getFirst(); arg_data_type_and_updated_output_index[entry.getFirst()] = { entry.getSecond(), update_index}; - if (!new_output_shapes.empty()) { - new_output_shapes[entry.getFirst()] = - tensorflow::ConvertTypeToTensorShapeAttr(entry.getSecond()); - } } AddLoadsStoresOutsideControlFlowOp(new_while, arg_data_type_and_updated_output_index); - new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); // Replace uses. 
- for (int64_t i = 0; i < old_to_new_indices.size(); ++i) { + for (int64_t i = 0, end = old_to_new_indices.size(); i < end; ++i) { if (old_to_new_indices[i] >= 0) { while_op.getResult(i).replaceAllUsesWith( new_while.getResult(old_to_new_indices[i])); @@ -687,10 +677,12 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { auto retval = func.front().getTerminator()->getOperand(result_index); assert(result.getType() == retval.getType()); auto aliasing_arg = retval.dyn_cast(); + if (!aliasing_arg) + return op.emitOpError("unsupported output: ") + << "resource does not alias input"; if (common_aliasing_arg_num == kUnassigned) common_aliasing_arg_num = aliasing_arg.getArgNumber(); - if (!aliasing_arg || - aliasing_arg.getArgNumber() != common_aliasing_arg_num) + if (aliasing_arg.getArgNumber() != common_aliasing_arg_num) return op.emitOpError("unsupported output: ") << "resource does not alias a single input"; } @@ -760,8 +752,11 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { for (auto branch : branches) { auto new_retvals = llvm::to_vector<4>(branch.front().getTerminator()->getOperands()); + new_retvals.resize(new_retvals.size() + resource_arg_to_new_output.size()); for (const auto& entry : resource_arg_to_new_output) { - new_retvals.push_back(branch.getArgument(entry.getFirst())); + int64_t resource_arg_index = entry.getFirst(); + int64_t output_index = entry.getSecond(); + new_retvals[output_index] = branch.getArgument(resource_arg_index); } auto old_return = branch.front().getTerminator(); OpBuilder builder(old_return); @@ -799,7 +794,7 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { AddLoadsStoresOutsideControlFlowOp(new_op, arg_data_type_and_updated_output_index); // Replace uses. - for (int64_t i = 0; i < old_to_new_output_indices.size(); ++i) { + for (int64_t i = 0, end = old_to_new_output_indices.size(); i < end; ++i) { if (old_to_new_output_indices[i] >= 0) { op.getResult(i).replaceAllUsesWith( new_op.getResult(old_to_new_output_indices[i])); @@ -943,7 +938,8 @@ void UpdatePartitionedCallOpWithNewCallee( AddLoadsStoresOutsideControlFlowOp( new_call, lifting_info.arg_data_type_and_updated_output_index); // Replace uses. - for (int64_t i = 0; i < lifting_info.old_to_new_output_indices.size(); ++i) { + for (int64_t i = 0, end = lifting_info.old_to_new_output_indices.size(); + i < end; ++i) { if (lifting_info.old_to_new_output_indices[i] >= 0) { call_op.getResult(i).replaceAllUsesWith( new_call.getResult(lifting_info.old_to_new_output_indices[i])); @@ -987,8 +983,8 @@ LogicalResult HoistForFunctionalControlFlow( RemoveIdentity(block); for (Operation& op : llvm::make_early_inc_range(*block)) { if (auto while_op = llvm::dyn_cast(&op)) { - auto body = llvm::cast(module.lookupSymbol(while_op.body())); - auto cond = llvm::cast(module.lookupSymbol(while_op.cond())); + auto body = while_op.body_func(); + auto cond = while_op.cond_func(); // Recursively handle the nested control flow. 
HoistForFunctionalControlFlow(&body.front(), module, lifted_partitioned_call_callees); @@ -996,10 +992,8 @@ LogicalResult HoistForFunctionalControlFlow( lifted_partitioned_call_callees); if (failed(HandleWhileLoop(while_op, body, cond))) return failure(); } else if (auto if_op = llvm::dyn_cast(&op)) { - auto then_branch = - llvm::cast(module.lookupSymbol(if_op.then_branch())); - auto else_branch = - llvm::cast(module.lookupSymbol(if_op.else_branch())); + auto then_branch = if_op.then_func(); + auto else_branch = if_op.else_func(); // Recursively handle the nested control flow. HoistForFunctionalControlFlow(&then_branch.front(), module, lifted_partitioned_call_callees); @@ -1020,12 +1014,10 @@ LogicalResult HoistForFunctionalControlFlow( } if (failed(HandleCaseOrIfOp(case_op, branch_functions))) return failure(); } else if (auto call_op = llvm::dyn_cast(&op)) { - if (!call_op.f().isa()) { + auto callee = call_op.func(); + if (!callee) return call_op.emitOpError( "resource lifting does not support call with nested references."); - } - auto callee = llvm::cast( - module.lookupSymbol(call_op.f().getRootReference())); if (failed(HandlePartitionedCallOp(call_op, callee, module, lifted_partitioned_call_callees))) { // Nested control flow handling is done in HandlePartitionedCallOp(). @@ -1033,8 +1025,7 @@ LogicalResult HoistForFunctionalControlFlow( } } else if (auto call_op = llvm::dyn_cast(&op)) { - auto callee = llvm::cast(module.lookupSymbol(call_op.f())); - if (failed(HandlePartitionedCallOp(call_op, callee, module, + if (failed(HandlePartitionedCallOp(call_op, call_op.func(), module, lifted_partitioned_call_callees))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index f9c81634ae5..597fbe2c0b1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -130,25 +131,28 @@ bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) { !IsSupportedNonTFOp(use.getOwner()); } -// Inserts tf.Cast operation when changing the type of a result if the user is -// not a TF operation, as we can't guarantee that the new type will be OK. -void AddCastBackForUnsupportedNonTFUses(Operation* op, Value result, - Dialect* tf_dialect, Type old_type) { - // A tf.Cast operation is lazily created on the first uses that isn't a TF - // operation. +// Updates the result of an operation to a new inferred type. Also inserts +// tf.Cast operation for uses that are incompatible with the new type. +void UpdateTypeAndInsertIncompatibleUseCasts(Dialect* tf_dialect, Type new_type, + Operation* op, Value result) { + // A tf.Cast operation is lazily created on the first use requires a cast. 
TF::CastOp cast_op; auto get_cast_op = [&]() { if (!cast_op) { OpBuilder b(op); b.setInsertionPointAfter(op); - cast_op = b.create(op->getLoc(), old_type, result, + cast_op = b.create(op->getLoc(), result.getType(), result, /*truncate=*/b.getBoolAttr(false)); } return Value(cast_op); }; + // First insert cast back for uses that need a cast and then + // update the type. for (OpOperand& use : make_early_inc_range(result.getUses())) { if (NeedsCastBack(use, tf_dialect)) use.set(get_cast_op()); } + + result.setType(new_type); } // Extracts a PartialTensorShape from the MLIR type. @@ -210,36 +214,49 @@ bool CanBeRefined(Type type) { shape_type.getElementType().isa()); } +// Returns whether `original_type` type can be refined with +// `potential_refined_type` type. +bool CanRefineTypeWith(Type original_type, Type potential_refined_type) { + if (original_type == potential_refined_type || !CanBeRefined(original_type)) + return false; + + auto shape_type = potential_refined_type.dyn_cast(); + if (!shape_type) return false; + if (shape_type.hasRank()) return true; + + auto element_type_with_subtype = + shape_type.getElementType().dyn_cast(); + return element_type_with_subtype && + !element_type_with_subtype.GetSubtypes().empty(); +} + +// Refines the type of `result` of `op` using the type `potential_refined_type`. +// Return true if the type was changed. +bool RefineResultType(Operation* op, Value result, + Type potential_refined_type) { + if (!CanRefineTypeWith(result.getType(), potential_refined_type)) + return false; + + UpdateTypeAndInsertIncompatibleUseCasts(op->getDialect(), + potential_refined_type, op, result); + return true; +} + // Infers the shape from a (Stateful)PartionedCall operation by looking up the // called function and propagating the return type. -bool InferShapeForCall(Operation* op) { - auto call_op = cast(op); - CallInterfaceCallable callable = call_op.getCallableForCallee(); - SymbolRefAttr sym = callable.dyn_cast(); - if (!sym) return false; - FuncOp func = dyn_cast(SymbolTable::lookupNearestSymbolFrom(op, sym)); +bool InferShapeForCall(CallOpInterface call_op) { + FuncOp func = dyn_cast(call_op.resolveCallable()); if (!func) return false; + Operation* op = call_op.getOperation(); bool changed = false; // Map each of the results of the call to the returned type of the // function. for (auto result : zip(op->getResults(), func.getType().getResults())) { - if (std::get<0>(result).getType() == std::get<1>(result)) continue; - // Skip already statically shaped results. - if (!CanBeRefined(std::get<0>(result).getType())) continue; - - auto shaped_type = std::get<0>(result).getType().cast(); - auto new_type = std::get<1>(result).dyn_cast(); - if (!new_type) continue; - - // Inserts a cast back to the original type if any user is not in the - // TF dialect. - AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), - op->getDialect(), shaped_type); - // Finally we inferred the shape and replace the type for this result. 
- std::get<0>(result).setType(new_type); - changed = true; + changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || + changed; } + return changed; } @@ -265,12 +282,43 @@ bool InferShapeForCast(CastOp op, Dialect* tf_dialect) { auto new_type = RankedTensorType::get( ranked_op_type.getShape(), result.getType().cast().getElementType()); - auto old_type = result.getType(); - result.setType(new_type); - AddCastBackForUnsupportedNonTFUses(op, op.getResult(), tf_dialect, old_type); + + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect, new_type, op, + op.getResult()); return true; } +// Infer the shape IfOp outputs based on the shapes of the then and else +// function result types. +bool InferShapeForIf(IfOp op) { + bool changed = false; + auto then_results = op.then_func().getType().getResults(); + auto else_results = op.else_func().getType().getResults(); + for (auto it : llvm::zip(op.getResults(), then_results, else_results)) { + // If then and else types do not match, skip refinement for that result. + if (std::get<1>(it) != std::get<2>(it)) continue; + changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed; + } + return changed; +} + +// Infer the shape IfRegion outputs based on the shapes of the then and else +// yields. +bool InferShapeForIfRegion(IfRegionOp op) { + bool changed = false; + + Operation* then_yield = op.then_branch().front().getTerminator(); + Operation* else_yield = op.else_branch().front().getTerminator(); + for (auto result : zip(op.getResults(), then_yield->getOperandTypes(), + else_yield->getOperandTypes())) { + // If then and else types do not match, skip refinement for that result. + if (std::get<1>(result) != std::get<2>(result)) continue; + changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) || + changed; + } + return changed; +} + bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, Dialect* tf_dialect) { Operation* op = infer_ti.getOperation(); @@ -291,12 +339,8 @@ bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti, for (auto result : zip(op->getResults(), inferred)) { if (std::get<0>(result).getType() == std::get<1>(result)) continue; - // Inserts a cast back to the original type if any user is not in the - // TF dialect. - AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), - op->getDialect(), std::get<1>(result)); - // Finally we inferred the shape and replace the type for this result. - std::get<0>(result).setType(std::get<1>(result)); + UpdateTypeAndInsertIncompatibleUseCasts( + op->getDialect(), std::get<1>(result), op, std::get<0>(result)); changed = true; } return changed; @@ -485,32 +529,37 @@ class ShapeInference { // 1) They are never reused, ie. having a single use in module. // 2) Their input types match those of their parent ops (excluding inputs // like predicate). - // Returns a boolean indicating whether any change has been applied. - LogicalResult RefineShapeForControlFlowFunc(FuncOp func, - ArrayRef input_types, - int64_t max_iteration); - - // Propagate the shapes to the functions named. LogicalResult PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, - ArrayRef func_names, int64_t max_iteration); + ArrayRef functions, int64_t max_iteration); + + // Propagates shapes to regions given the shapes of the inputs of the regions. + // All regions provided in `regions` are assumed to have inputs of type + // `input_types`. 
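// Aside: InferShapeForIf/InferShapeForIfRegion above refine an op result only
// when the then and else branches agree on the type and the candidate actually
// adds information (a rank, or resource/variant subtypes). A standalone sketch
// of that per-result decision with a hypothetical TypeInfo summary; it mirrors
// the decision logic only, not MLIR's type system.
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical summary of the properties that drive the refinement decision.
struct TypeInfo {
  bool fully_static = false;  // nothing left to refine
  bool has_rank = false;      // ranked (possibly with dynamic dimensions)
  bool has_subtypes = false;  // resource/variant element with known subtypes
  bool operator==(const TypeInfo& o) const {
    return fully_static == o.fully_static && has_rank == o.has_rank &&
           has_subtypes == o.has_subtypes;
  }
};

// Analogue of CanRefineTypeWith: skip no-ops and candidates that add nothing.
bool CanRefineTypeWith(const TypeInfo& original, const TypeInfo& candidate) {
  if (original == candidate || original.fully_static) return false;
  return candidate.has_rank || candidate.has_subtypes;
}

// Analogue of InferShapeForIf: a result is refined only if both branches agree
// on the type and that type refines the current result type.
int RefineIfResults(std::vector<TypeInfo>& results,
                    const std::vector<TypeInfo>& then_types,
                    const std::vector<TypeInfo>& else_types) {
  int changed = 0;
  for (std::size_t i = 0; i < results.size(); ++i) {
    if (!(then_types[i] == else_types[i])) continue;
    if (!CanRefineTypeWith(results[i], then_types[i])) continue;
    results[i] = then_types[i];
    ++changed;
  }
  return changed;
}

int main() {
  std::vector<TypeInfo> results{{}, {}};                       // both unranked
  std::vector<TypeInfo> then_types{{false, true, false}, {}};  // first is ranked
  std::vector<TypeInfo> else_types{{false, true, false}, {false, true, false}};
  std::cout << RefineIfResults(results, then_types, else_types) << "\n";  // 1
}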
+ LogicalResult PropagateShapeToRegions( + Operation::operand_type_range input_types, ArrayRef regions, + int64_t max_iteration); // Shape propagation for call/control flow ops. LogicalResult PropagateShapeIntoAttachedFunctions(Operation* op, int64_t max_iteration); + // Shape propagation for region based control flow. + LogicalResult PropagateShapeIntoAttachedRegions(Operation* op, + int64_t max_iterations); + // Propagates any constant operand of call_op to the called function body's // corresponding argument if the callee has only one use. // // TODO(b/154065712): Move this to a more general inter-procedural constant // folding pass. - void PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module); + void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func, + ModuleOp module); // Propagates any constant return value of the callee function to the call // op's corresponding result. - void PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, ModuleOp module); + void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func, + ModuleOp module); // Tries to compute the result of folding the op. This doesn't actually // perform constant folding, it is just computes the equivalent constants. @@ -635,8 +684,8 @@ bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op, .isa()) continue; - std::get<1>(entry).setType(operand_type); - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, result_type); + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, operand_type, op, + result); changed = true; } return changed; @@ -666,13 +715,12 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) { result_type.getShape() == operand_type.getShape()) continue; if (!is_allowed_dtype(operand_type.getElementType()) || - !is_allowed_dtype(result_type.getElementType())) { + !is_allowed_dtype(result_type.getElementType())) continue; - } - result.setType(RankedTensorType::get(operand_type.getShape(), - result_type.getElementType())); - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, result_type); + auto new_type = RankedTensorType::get(operand_type.getShape(), + result_type.getElementType()); + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, op, result); changed = true; } return changed; @@ -712,7 +760,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. - if (isa(op)) { + if (isa(op)) { return RefineTypeForPassThroughOperands(op, op->getOperands(), op->getResults()); } @@ -728,9 +777,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // Handle call operations by looking up callee and infering return shape as // needed. - if (isa( - op)) - return InferShapeForCall(op); + if (auto call = dyn_cast(op)) return InferShapeForCall(call); // tf.Cast are only inferred if they have at least one user in the TF dialect // or feeding into the function return. This is necessary to avoid inserting @@ -738,6 +785,17 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { if (auto cast_op = dyn_cast(op)) return InferShapeForCast(cast_op, tf_dialect_); + // Handle IfOp here by inferring the shape from the else/then function + // results. 
Since `output_shapes` is a derived attribute, avoid going down the + // TF InferenceContext path as IfOp shape inference is implemented as just + // a lookup of the output_shapes attribute. + if (auto if_op = dyn_cast(op)) return InferShapeForIf(if_op); + + // Handle IfRegion operations by infering return shape from the then and else + // branches. + if (auto if_region = dyn_cast(op)) + return InferShapeForIfRegion(if_region); + StringRef op_name = op->getName().getStringRef(); // Drop the `tf.` prefix to query TF registry. auto node_name = @@ -910,12 +968,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { } auto new_type = get_tensor_type(shape_handle, new_element_type); if (result.getType() == new_type) continue; - // Inserts a cast back to the original type if any user is not in the TF - // dialect or a return. - AddCastBackForUnsupportedNonTFUses(op, result, tf_dialect_, - result.getType()); - // Finally we inferred the shape and replace the type for this result. - result.setType(new_type); + + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, new_type, op, result); changed = true; } if (changed) @@ -924,59 +978,72 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { return changed; } -LogicalResult ShapeInference::RefineShapeForControlFlowFunc( - FuncOp func, ArrayRef input_types, int64_t max_iteration) { - ModuleOp module = func.getParentOfType(); - auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); - int num_uses = std::distance(func_uses->begin(), func_uses->end()); - if (num_uses != 1) { - func.emitWarning(formatv( - "expected control flow function {0} to have exactly 1 use, found {1}.", - func.getName(), num_uses)); - return failure(); - } - - FunctionType func_type = func.getType(); - func.setType(FunctionType::get(input_types, func_type.getResults(), - func.getContext())); - - for (auto arg_and_idx : llvm::enumerate(func.getArguments())) { - arg_and_idx.value().setType(input_types[arg_and_idx.index()]); - } - - auto res = InferShapeUntilFixPoint(&func.getBody(), max_iteration); - if (failed(res)) return res; - - auto new_return_types = InferShapeForFunctionReturnType(func); - if (new_return_types.hasValue()) { - func.setType(FunctionType::get(input_types, new_return_types.getValue(), - func.getContext())); - } - - return success(); -} - LogicalResult ShapeInference::PropagateShapeToFunctions( ModuleOp module, Operation::operand_type_range input_types, - ArrayRef func_names, int64_t max_iteration) { + ArrayRef functions, int64_t max_iteration) { bool all_succeeded = true; auto types = llvm::to_vector<4>(input_types); - for (auto func_name : func_names) { - FuncOp func = module.lookupSymbol(func_name); - all_succeeded = - succeeded(RefineShapeForControlFlowFunc(func, types, max_iteration)) && - all_succeeded; + // If shape propagation fails for one function, return failure, but do not + // early exit and attempt to propagate shapes for all provided functions to + // have a best-effort propagation. 
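// Aside: PropagateShapeToFunctions above deliberately keeps going after a
// failure so every callee gets a chance to be refined, and only reports failure
// at the end. A compact standalone sketch of that best-effort accumulation with
// a hypothetical refine callback; the callback and function names are
// illustrative only.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Attempts every function, remembers whether any attempt failed, never exits
// early (mirrors the `all_succeeded = ... && all_succeeded` pattern above).
bool PropagateToAll(const std::vector<std::string>& functions,
                    const std::function<bool(const std::string&)>& refine) {
  bool all_succeeded = true;
  for (const std::string& name : functions)
    all_succeeded = refine(name) && all_succeeded;
  return all_succeeded;
}

int main() {
  auto refine = [](const std::string& name) {
    bool ok = name != "multi_use_fn";  // e.g. skip functions with >1 use
    std::cout << (ok ? "refined " : "skipped ") << name << "\n";
    return ok;
  };
  bool ok = PropagateToAll({"then_fn", "multi_use_fn", "else_fn"}, refine);
  std::cout << std::boolalpha << ok << "\n";  // false, but all were attempted
}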
+ for (FuncOp func : functions) { + auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + if (!llvm::hasSingleElement(func_uses.getValue())) { + int num_uses = std::distance(func_uses->begin(), func_uses->end()); + func.emitWarning( + formatv("expected control flow function @{0} to have exactly 1 use, " + "found {1}.", + func.getName(), num_uses)); + all_succeeded = false; + continue; + } + + FunctionType func_type = func.getType(); + func.setType( + FunctionType::get(types, func_type.getResults(), func.getContext())); + + auto res = + PropagateShapeToRegions(input_types, {&func.getBody()}, max_iteration); + if (failed(res)) { + all_succeeded = false; + continue; + } + + auto new_return_types = InferShapeForFunctionReturnType(func); + if (new_return_types) + func.setType(FunctionType::get(types, new_return_types.getValue(), + func.getContext())); + } + return success(all_succeeded); +} + +LogicalResult ShapeInference::PropagateShapeToRegions( + Operation::operand_type_range input_types, ArrayRef regions, + int64_t max_iteration) { + bool all_succeeded = true; + auto types = llvm::to_vector<4>(input_types); + // If shape propagation fails for one region, return failure, but do not + // early exit and attempt to propagate shapes for all provided regions to + // have a best-effort propagation. + for (auto region : regions) { + // Refine region arguments. + Block& entry = region->front(); + assert(types.size() == entry.getNumArguments()); + for (auto arg_and_idx : llvm::enumerate(entry.getArguments())) { + arg_and_idx.value().setType(types[arg_and_idx.index()]); + } + + // Propagate shapes into the region. + all_succeeded = succeeded(InferShapeUntilFixPoint(region, max_iteration)) && + all_succeeded; } return success(all_succeeded); } void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, - ModuleOp module) { - auto func = module.lookupSymbol(callee_sym.getRootReference()); + FuncOp func, ModuleOp module) { auto func_uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion()); - int num_uses = std::distance(func_uses->begin(), func_uses->end()); - if (num_uses != 1) return; + if (!llvm::hasSingleElement(func_uses.getValue())) return; OpBuilder builder(&func.front().front()); Operation* op = call_op.getOperation(); @@ -1002,9 +1069,7 @@ void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op, } void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op, - SymbolRefAttr callee_sym, - ModuleOp module) { - auto func = module.lookupSymbol(callee_sym.getRootReference()); + FuncOp func, ModuleOp module) { // If the return value is a constant, use the constant as the value of // the call return. 
Operation* op = call_op.getOperation(); @@ -1036,28 +1101,29 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( if (auto if_op = dyn_cast(op)) { return PropagateShapeToFunctions( module, drop_begin(if_op.getOperandTypes(), 1), - {if_op.then_branch(), if_op.else_branch()}, max_iteration); + {if_op.then_func(), if_op.else_func()}, max_iteration); } else if (auto case_op = dyn_cast(op)) { - SmallVector branches; - for (Attribute branch : case_op.branches()) - branches.push_back(branch.cast().getValue()); + SmallVector branches; + for (Attribute branch : case_op.branches()) { + auto sym = branch.cast(); + branches.push_back(SymbolTable::lookupNearestSymbolFrom(op, sym)); + } return PropagateShapeToFunctions(module, drop_begin(case_op.getOperandTypes(), 1), branches, max_iteration); } else if (auto while_op = dyn_cast(op)) { - return PropagateShapeToFunctions(module, while_op.getOperandTypes(), - {while_op.cond(), while_op.body()}, - max_iteration); + return PropagateShapeToFunctions( + module, while_op.getOperandTypes(), + {while_op.cond_func(), while_op.body_func()}, max_iteration); } else if (auto call_op = dyn_cast(op)) { - CallInterfaceCallable callable = call_op.getCallableForCallee(); - if (SymbolRefAttr sym = callable.dyn_cast()) { - PropagateConstantToCallee(call_op, sym, module); - if (failed(PropagateShapeToFunctions( - module, call_op.getArgOperands().getTypes(), - {sym.getRootReference()}, max_iteration))) { + if (auto func = dyn_cast(call_op.resolveCallable())) { + PropagateConstantToCallee(call_op, func, module); + if (failed(PropagateShapeToFunctions(module, + call_op.getArgOperands().getTypes(), + {func}, max_iteration))) { return failure(); } - PropagateConstantFromCallee(call_op, sym, module); + PropagateConstantFromCallee(call_op, func, module); return success(); } } @@ -1067,6 +1133,16 @@ LogicalResult ShapeInference::PropagateShapeIntoAttachedFunctions( return success(); } +LogicalResult ShapeInference::PropagateShapeIntoAttachedRegions( + Operation* op, int64_t max_iteration) { + if (auto while_op = dyn_cast(op)) { + return PropagateShapeToRegions(while_op.getOperandTypes(), + {&while_op.cond(), &while_op.body()}, + max_iteration); + } + return success(); +} + LogicalResult ShapeInference::TryToFold(Operation* op) { LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n"); // If any output result is known, then the op probably has been computed @@ -1118,12 +1194,8 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { if (ElementsAttr eattr = attr.dyn_cast_or_null()) { if (std::get<0>(result).getType() == eattr.getType()) continue; - // Inserts a cast back to the original type if any user is not in the - // TF dialect. 
- Type old_type = std::get<0>(result).getType(); - std::get<0>(result).setType(eattr.getType()); - AddCastBackForUnsupportedNonTFUses(op, std::get<0>(result), tf_dialect_, - old_type); + UpdateTypeAndInsertIncompatibleUseCasts(tf_dialect_, eattr.getType(), op, + std::get<0>(result)); } } @@ -1164,6 +1236,11 @@ LogicalResult ShapeInference::InferShapeUntilFixPoint(Region* region, "arguments and bodies"; } + if (failed(PropagateShapeIntoAttachedRegions(op, max_iteration))) { + op->emitWarning() << "unable to refine shape of attached region " + "arguments and bodies"; + } + changed |= InferShapeForSingleOperation(op); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 5e095a311ee..d3755a4a7d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -163,7 +163,7 @@ LogicalResult HandleWhileOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = module.lookupSymbol(while_op.body()); + auto body = while_op.body_func(); llvm::SmallDenseMap body_map; auto find_arg_stack_type = [&](int64_t index) -> llvm::Optional { auto it = data_var_to_size_var.find(while_op.getOperand(index)); @@ -187,7 +187,7 @@ LogicalResult HandleWhileOp( return failure(); } // Cond should not change stacks in the arguments, so use an empty map. - auto cond = module.lookupSymbol(while_op.cond()); + auto cond = while_op.cond_func(); ModifyFunctionSignature(cond, nullptr, find_arg_stack_type); llvm::SmallDenseMap empty_map; if (failed(DecomposeStackOpsInternal(&cond.front(), module, &empty_map, @@ -197,24 +197,16 @@ LogicalResult HandleWhileOp( if (!signature_change) return success(); // Create the new while op. auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); - auto new_output_shapes = - llvm::to_vector<8>(while_op.output_shapes().getValue()); OpBuilder builder(while_op); assert(while_op.getNumOperands() == while_op.getNumResults()); for (int64_t i = 0; i < while_op.getNumResults(); ++i) { auto it = data_var_to_size_var.find(while_op.getOperand(i)); if (it == data_var_to_size_var.end()) continue; new_while_operands.push_back(it->getSecond()); - if (!new_output_shapes.empty()) { - // Size is a scalar shape. 
- new_output_shapes.push_back( - mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); - } } auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), new_while_operands, while_op.getAttrs()); - new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); for (int64_t i = 0; i < while_op.getNumResults(); ++i) { if (!getElementTypeOrSelf(while_op.getOperand(i).getType()) .isa()) { @@ -239,8 +231,8 @@ LogicalResult HandleIfOp( const llvm::SmallDenseMap& data_var_to_size_var, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_branch = module.lookupSymbol(if_op.then_branch()); - auto else_branch = module.lookupSymbol(if_op.else_branch()); + auto then_func = if_op.then_func(); + auto else_func = if_op.else_func(); llvm::SmallDenseMap then_map; llvm::SmallDenseMap else_map; @@ -249,12 +241,12 @@ LogicalResult HandleIfOp( if (it == data_var_to_size_var.end()) return llvm::None; return it->getFirst().getType(); }; - ModifyFunctionSignature(then_branch, &then_map, find_arg_stack_type); - ModifyFunctionSignature(else_branch, &else_map, find_arg_stack_type); + ModifyFunctionSignature(then_func, &then_map, find_arg_stack_type); + ModifyFunctionSignature(else_func, &else_map, find_arg_stack_type); const bool signature_change = !then_map.empty() || !else_map.empty(); - if (failed(DecomposeStackOpsInternal(&then_branch.front(), module, &then_map, + if (failed(DecomposeStackOpsInternal(&then_func.front(), module, &then_map, decomposed_partitioned_call_callees)) || - failed(DecomposeStackOpsInternal(&else_branch.front(), module, &else_map, + failed(DecomposeStackOpsInternal(&else_func.front(), module, &else_map, decomposed_partitioned_call_callees))) { return failure(); } @@ -266,16 +258,16 @@ LogicalResult HandleIfOp( new_if_operands.push_back(it->getSecond()); } auto new_if = OpBuilder(if_op).create( - if_op.getLoc(), then_branch.getType().getResults(), new_if_operands, + if_op.getLoc(), then_func.getType().getResults(), new_if_operands, if_op.getAttrs()); for (auto result : if_op.getResults()) { if (!getElementTypeOrSelf(result.getType()).isa()) { continue; } int64_t then_aliased_input = - FindAliasedInput(then_branch, result.getResultNumber()); + FindAliasedInput(then_func, result.getResultNumber()); int64_t else_aliased_input = - FindAliasedInput(else_branch, result.getResultNumber()); + FindAliasedInput(else_func, result.getResultNumber()); if (then_aliased_input >= 0 && then_aliased_input == else_aliased_input) { // Replace aliased stack output uses with input. result.replaceAllUsesWith(if_op.getOperand(then_aliased_input + 1)); @@ -409,11 +401,9 @@ LogicalResult HandleStackV2Op( ArrayRef{buffer.getType().cast()}, stack.getContext())); auto local_var = builder.create( - stack.getLoc(), ArrayRef{var_type}, ArrayRef{}, - ArrayRef{}); + stack.getLoc(), ArrayRef{var_type}, ArrayRef{}); auto local_size_var = builder.create( - stack.getLoc(), ArrayRef{size_var_type}, ArrayRef{}, - ArrayRef{}); + stack.getLoc(), ArrayRef{size_var_type}, ArrayRef{}); // Zero-initialize the local vars. 
cutil::WriteLocalVariable(local_size_var, cutil::GetR1Const({0LL}, builder, stack.getLoc()), @@ -446,8 +436,7 @@ LogicalResult HandleStackPushV2Op( cutil::WriteLocalVariable(push.handle(), stack_val, builder, push.getLoc()); index = builder.create( push.getLoc(), ArrayRef{index.getType()}, - ArrayRef{index, cutil::GetR1Const({1}, builder, push.getLoc())}, - ArrayRef{}); + ArrayRef{index, cutil::GetR1Const({1}, builder, push.getLoc())}); cutil::WriteLocalVariable(it->getSecond(), index, builder, push.getLoc()); push.erase(); return success(); @@ -467,8 +456,7 @@ LogicalResult HandleStackPopV2Op( auto size = cutil::ReadLocalVariable(it->getSecond(), builder, pop.getLoc()); auto new_size = builder.create( pop.getLoc(), ArrayRef{size.getType()}, - ArrayRef{size, cutil::GetR1Const({1}, builder, pop.getLoc())}, - ArrayRef{}); + ArrayRef{size, cutil::GetR1Const({1}, builder, pop.getLoc())}); auto pop_val = cutil::GetElement(new_size, stack_val, builder, pop.getLoc()); pop.replaceAllUsesWith(pop_val); // Update the size. @@ -519,21 +507,20 @@ LogicalResult DecomposeStackOpsInternal( return failure(); } } else if (auto pcall = llvm::dyn_cast(&op)) { - if (!pcall.f().isa()) { + if (!pcall.func()) { return pcall.emitOpError( "stack decomposition does not support call with nested references"); } if (failed(HandlePartitionedCallOp( - pcall, module.lookupSymbol(pcall.f().getRootReference()), - module, *data_var_to_size_var, + pcall, pcall.func(), module, *data_var_to_size_var, decomposed_partitioned_call_callees))) { return failure(); } } else if (auto spcall = llvm::dyn_cast(&op)) { if (failed(HandlePartitionedCallOp( - spcall, module.lookupSymbol(spcall.f()), module, - *data_var_to_size_var, decomposed_partitioned_call_callees))) { + spcall, spcall.func(), module, *data_var_to_size_var, + decomposed_partitioned_call_callees))) { return failure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 9c659a95078..b3a05c06a67 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -166,8 +166,7 @@ LogicalResult HandleTensorArrayV3Op( ArrayRef{buffer.getType().cast()}, ta.getContext())); auto local_var = builder.create( - ta.getLoc(), ArrayRef{var_type}, ArrayRef{}, - ArrayRef{}); + ta.getLoc(), ArrayRef{var_type}, ArrayRef{}); cutil::WriteLocalVariable(local_var, buffer, builder, ta.getLoc()); ta.handle().replaceAllUsesWith(local_var); // The flow output is just a way for the front end to enforce ordering among @@ -227,8 +226,7 @@ LogicalResult HandleTensorArrayWriteV3Op( elem = builder.create( write.getLoc(), ArrayRef{slice_type}, ArrayRef{elem, cutil::GetR1Const(slice_type.getShape(), builder, - write.getLoc())}, - ArrayRef{}); + write.getLoc())}); elem = cutil::AccumulateBuffers(elem, original_elem, builder, write.getLoc()); } @@ -261,8 +259,7 @@ LogicalResult HandleTensorArrayConcatV3Op( ArrayRef{ RankedTensorType::get(shape, buffer_type.getElementType())}, ArrayRef{buffer, - cutil::GetR1Const(shape, builder, concat.getLoc())}, - ArrayRef{}); + cutil::GetR1Const(shape, builder, concat.getLoc())}); concat.value().replaceAllUsesWith(buffer); // Create the lengths as a list of the same value (element size). 
@@ -302,8 +299,7 @@ LogicalResult HandleTensorArraySplitV3Op( buffer_shape, elem_type.getElementType())}, ArrayRef{split.value(), cutil::GetR1Const(buffer_shape, builder, - split.getLoc())}, - ArrayRef{}) + split.getLoc())}) .output(); // Accumulate with the old buffer. auto old_buffer = @@ -339,8 +335,7 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type, Operation* op, Value* var) { OpBuilder builder(op); *var = builder.create( - op->getLoc(), ArrayRef{local_var_type}, ArrayRef{}, - ArrayRef{}); + op->getLoc(), ArrayRef{local_var_type}, ArrayRef{}); Value buffer; auto buffer_type = getElementTypeOrSelf(local_var_type) .cast() @@ -447,38 +442,20 @@ llvm::SmallDenseMap> AccessedGradients( if (auto grad = llvm::dyn_cast(&op)) { insert(grad.handle(), grad.source().str()); } else if (auto while_op = llvm::dyn_cast(&op)) { - auto body = module.lookupSymbol(while_op.body()); - auto cond = module.lookupSymbol(while_op.cond()); - for (const auto& entry : AccessedGradients({body, cond}, module)) { - for (const string& source : entry.getSecond()) { + for (const auto& entry : AccessedGradients( + {while_op.body_func(), while_op.cond_func()}, module)) + for (const string& source : entry.getSecond()) insert(while_op.getOperand(entry.getFirst()), source); - } - } } else if (auto if_op = llvm::dyn_cast(&op)) { - auto then_branch = module.lookupSymbol(if_op.then_branch()); - auto else_branch = module.lookupSymbol(if_op.else_branch()); for (const auto& entry : - AccessedGradients({then_branch, else_branch}, module)) { - for (const string& source : entry.getSecond()) { + AccessedGradients({if_op.then_func(), if_op.else_func()}, module)) + for (const string& source : entry.getSecond()) insert(if_op.getOperand(entry.getFirst() + 1), source); - } - } - } else if (auto pc = llvm::dyn_cast(&op)) { - if (!pc.f().isa()) continue; - auto callee = module.lookupSymbol(pc.f().getRootReference()); - for (const auto& entry : AccessedGradients({callee}, module)) { - for (const string& source : entry.getSecond()) { - insert(pc.getOperand(entry.getFirst()), source); - } - } - } else if (auto spc = - llvm::dyn_cast(&op)) { - auto callee = module.lookupSymbol(spc.f()); - for (const auto& entry : AccessedGradients({callee}, module)) { - for (const string& source : entry.getSecond()) { - insert(spc.getOperand(entry.getFirst()), source); - } - } + } else if (auto call = llvm::dyn_cast(&op)) { + auto callee = dyn_cast(call.resolveCallable()); + for (const auto& entry : AccessedGradients({callee}, module)) + for (const string& source : entry.getSecond()) + insert(call.getArgOperands()[entry.getFirst()], source); } } } @@ -532,8 +509,8 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto body = module.lookupSymbol(while_op.body()); - auto cond = module.lookupSymbol(while_op.cond()); + auto body = while_op.body_func(); + auto cond = while_op.cond_func(); auto grads = AccessedGradients({body, cond}, module); auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(while_op.getOperand(index)); @@ -600,8 +577,6 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), operands, while_op.getAttrs()); - // Clear the output shapes as it is not needed for XLA lowering. 
- new_while.setAttr("output_shapes", builder.getArrayAttr({})); for (int64_t i = 0; i < while_op.getNumOperands(); ++i) { if (ta_arg_buffer_type(i)) { while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i)); @@ -617,8 +592,8 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, llvm::SmallDenseMap* stats, llvm::StringMap* decomposed_partitioned_call_callees) { - auto then_branch = module.lookupSymbol(if_op.then_branch()); - auto else_branch = module.lookupSymbol(if_op.else_branch()); + auto then_branch = if_op.then_func(); + auto else_branch = if_op.else_func(); auto grads = AccessedGradients({then_branch, else_branch}, module); auto ta_arg_buffer_type = [&](int64_t index) -> Type { auto it = stats->find(if_op.getOperand(index + 1)); @@ -668,8 +643,6 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, auto new_if = builder.create(if_op.getLoc(), then_branch.getType().getResults(), operands, if_op.getAttrs()); - // Clear the output shapes as it is not needed for XLA lowering. - new_if.setAttr("output_shapes", builder.getArrayAttr({})); auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t { auto retval = f.front().getTerminator()->getOperand(ret_ind); auto arg = retval.dyn_cast(); @@ -847,21 +820,22 @@ LogicalResult DecomposeTensorArrayOps( return failure(); } } else if (auto pcall = llvm::dyn_cast(&op)) { - if (!pcall.f().isa()) { + auto callee = pcall.func(); + if (!callee) return pcall.emitOpError( "TensorArray decomposition does not support call with nested " "references."); - } - if (failed(HandlePartitionedCallOp( - pcall, module.lookupSymbol(pcall.f().getRootReference()), - module, stats, decomposed_partitioned_call_callees))) { + + if (failed( + HandlePartitionedCallOp(pcall, callee, module, stats, + decomposed_partitioned_call_callees))) { return failure(); } } else if (auto spcall = llvm::dyn_cast(&op)) { - if (failed(HandlePartitionedCallOp( - spcall, module.lookupSymbol(spcall.f()), module, stats, - decomposed_partitioned_call_callees))) { + if (failed( + HandlePartitionedCallOp(spcall, spcall.func(), module, stats, + decomposed_partitioned_call_callees))) { return failure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 11153f0dfc3..9634e4a8be3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -155,7 +155,7 @@ LogicalResult HandleWhileOp( llvm::StringMap* decomposed_partitioned_call_callees) { // Rewrite body. - auto body = module.lookupSymbol(while_op.body()); + auto body = while_op.body_func(); llvm::SmallDenseMap body_map; auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional { auto it = buffer_to_size->find(while_op.getOperand(index)); @@ -176,7 +176,7 @@ LogicalResult HandleWhileOp( auto output_buffer_to_size = AddTensorListSizesToReturn(body, body_map); // Rewrite cond. - auto cond = module.lookupSymbol(while_op.cond()); + auto cond = while_op.cond_func(); llvm::SmallDenseMap cond_map; ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map, find_arg_tensor_list_type, arg_buffer_size_is_fixed); @@ -190,22 +190,14 @@ LogicalResult HandleWhileOp( } // Create the new while op. 
auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); - auto new_output_shapes = - llvm::to_vector<8>(while_op.output_shapes().getValue()); for (int64_t i = 0; i < while_op.getNumResults(); ++i) { auto it = buffer_to_size->find(while_op.getOperand(i)); if (it == buffer_to_size->end()) continue; new_while_operands.push_back(it->getSecond().size); - if (!new_output_shapes.empty()) { - // Size is a scalar shape. - new_output_shapes.push_back( - mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); - } } auto new_while = builder.create(while_op.getLoc(), body.getType().getInputs(), new_while_operands, while_op.getAttrs()); - new_while.setAttr("output_shapes", builder.getArrayAttr(new_output_shapes)); for (const auto& entry : output_buffer_to_size) { (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = { new_while.getResult(std::get<1>(entry)), std::get<2>(entry)}; @@ -438,7 +430,7 @@ LogicalResult HandleTensorListFromTensorOp( OpBuilder builder(list); Value buffer = builder.create( list.getLoc(), ArrayRef{list.tensor().getType()}, - ArrayRef{list.tensor()}, ArrayRef{}); + ArrayRef{list.tensor()}); auto type = buffer.getType().cast(); if (!type.hasStaticShape()) { return list.emitOpError("TensorListFromTensorOp input has unknown shape."); @@ -468,8 +460,7 @@ LogicalResult HandleTensorListPushBackOp( cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc()); auto new_size = builder.create( push.getLoc(), ArrayRef{size.getType()}, - ArrayRef{size, cutil::GetR1Const({1LL}, builder, push.getLoc())}, - ArrayRef{}); + ArrayRef{size, cutil::GetR1Const({1LL}, builder, push.getLoc())}); push.output_handle().replaceAllUsesWith(new_buffer); (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false}; push.erase(); @@ -491,12 +482,10 @@ LogicalResult HandleTensorListPopBackOp( auto size = it->getSecond().size; OpBuilder builder(pop); auto new_buffer = builder.create( - pop.getLoc(), ArrayRef{buffer.getType()}, ArrayRef{buffer}, - ArrayRef{}); + pop.getLoc(), ArrayRef{buffer.getType()}, ArrayRef{buffer}); auto new_size = builder.create( pop.getLoc(), ArrayRef{size.getType()}, - ArrayRef{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())}, - ArrayRef{}); + ArrayRef{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())}); auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc()); pop.output_handle().replaceAllUsesWith(new_buffer); pop.tensor().replaceAllUsesWith(element); @@ -567,8 +556,7 @@ LogicalResult HandleTensorListLengthOp( ArrayRef{RankedTensorType::get( {}, getElementTypeOrSelf(current_size.getType()))}, ArrayRef{current_size, - cutil::GetR1Const({}, builder, length.getLoc())}, - ArrayRef{}); + cutil::GetR1Const({}, builder, length.getLoc())}); length.length().replaceAllUsesWith(reshape); } length.erase(); @@ -713,11 +701,8 @@ LogicalResult DecomposeTensorListOpsInternal( return failure(); } } else if (auto if_op = llvm::dyn_cast(&op)) { - auto then_branch = module.lookupSymbol(if_op.then_branch()); - auto else_branch = module.lookupSymbol(if_op.else_branch()); - - if (failed(HandleCaseOrIfOp(if_op, {then_branch, else_branch}, module, - buffer_to_size, + if (failed(HandleCaseOrIfOp(if_op, {if_op.then_func(), if_op.else_func()}, + module, buffer_to_size, decomposed_partitioned_call_callees))) { return failure(); } @@ -732,21 +717,21 @@ LogicalResult DecomposeTensorListOpsInternal( return failure(); } } else if (auto pcall = llvm::dyn_cast(&op)) { - if (!pcall.f().isa()) { + if (!pcall.func()) return pcall.emitOpError( "TensorList 
decomposition does not support call with nested " "references."); - } + if (failed(HandlePartitionedCallOp( - pcall, module.lookupSymbol(pcall.f().getRootReference()), - module, buffer_to_size, decomposed_partitioned_call_callees))) { + pcall, pcall.func(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { return failure(); } } else if (auto spcall = llvm::dyn_cast(&op)) { if (failed(HandlePartitionedCallOp( - spcall, module.lookupSymbol(spcall.f()), module, - buffer_to_size, decomposed_partitioned_call_callees))) { + spcall, spcall.func(), module, buffer_to_size, + decomposed_partitioned_call_callees))) { return failure(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc new file mode 100644 index 00000000000..920b2024c0f --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_resource_alias_analysis.cc @@ -0,0 +1,111 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TF { +namespace { + +// A pass that annotates each operation with a resource type result with the +// aliasing values for each such result. Each value is assigned a unique ID, and +// that ID is used to annotate the operations. +struct TestResourceAliasAnalysis + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TestResourceAliasAnalysis, TF::ResourceAliasAnalysis> { + void runOnFunction(FuncOp func, + const TF::ResourceAliasAnalysis::Info& analysis) { + int64_t next_id = 0; + llvm::SmallDenseMap ids; + + auto assign_id = [&](Value value) { + if (ids.find(value) == ids.end()) ids.insert({value, next_id++}); + }; + + auto get_id = [&](Value value) -> int64_t { + auto it = ids.find(value); + assert(it != ids.end()); + return it->second; + }; + + auto print_aliases = [&](InFlightDiagnostic& diag, Value value) { + diag << ", ID " << get_id(value) << " : "; + if (analysis.IsUnknownResource(value)) { + diag << "Unknown"; + } else { + auto aliases = llvm::to_vector<4>(analysis.GetResourceAliases(value)); + llvm::sort(aliases, + [&](Value v1, Value v2) { return get_id(v1) < get_id(v2); }); + llvm::interleaveComma(aliases, diag, + [&](Value v) { diag << get_id(v); }); + } + }; + + // Assign a unique ID to each value seen in this function. 
+ func.walk([&](Operation* op) {
+ // For all attached regions, assign ID to the region arguments.
+ for (Region& region : op->getRegions()) {
+ for (auto region_arg : filter_resources(region.getArguments()))
+ assign_id(region_arg);
+ }
+
+ // Assign ID for all results.
+ for (auto result : filter_resources(op->getResults())) assign_id(result);
+ });
+
+ // Now walk each operation, and annotate it with remarks for aliases for
+ // each resource type result.
+ func.walk([&](Operation* op) {
+ // For all attached regions, assign ID to the region arguments.
+ for (Region& region : op->getRegions()) {
+ for (auto region_arg : filter_resources(region.getArguments())) {
+ InFlightDiagnostic diag = op->emitRemark("Region #")
+ << region.getRegionNumber() << ", Arg #"
+ << region_arg.getArgNumber();
+ print_aliases(diag, region_arg);
+ }
+ }
+
+ for (auto result : filter_resources(op->getResults())) {
+ InFlightDiagnostic diag = op->emitRemark("Result #")
+ << result.getResultNumber();
+ print_aliases(diag, result);
+ }
+ });
+ }
+};
+
+static mlir::PassRegistration pass(
+ "tf-test-resource-alias-analysis",
+ "Add remarks based on resource alias analysis result, for testing "
+ "purpose.");
+
+} // anonymous namespace
+} // namespace TF
+} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc
index 6b284222526..405c529840b 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/test_side_effect_analysis.cc
@@ -39,11 +39,13 @@ namespace {
// A pass that adds "Predecessors" and "Successors" remarks for each op based on
// SideEffectAnalysis result. For testing purpose only.
struct TestSideEffectAnalysis
- : public mlir::PassWrapper {
- void runOnFunction() override {
+ : public TF::PerFunctionAggregateAnalysisConsumerPass<
+ TestSideEffectAnalysis, TF::SideEffectAnalysis> {
+ void runOnFunction(FuncOp func,
+ const TF::SideEffectAnalysis::Info& analysis) {
int64_t next_id = 0;
llvm::SmallDenseMap ids;
- getFunction().walk([&](Operation* op) {
+ func.walk([&](Operation* op) {
ids[op] = next_id++;
op->emitRemark("ID: ") << ids[op];
});
@@ -53,8 +55,7 @@ struct TestSideEffectAnalysis
for (auto op : ops) id_vec.push_back(std::to_string(ids[op]));
return llvm::join(id_vec, ",");
};
- auto& analysis = getAnalysis();
- getFunction().walk([&](Operation* op) {
+ func.walk([&](Operation* op) {
if (!analysis.DirectControlPredecessors(op).empty()) {
op->emitRemark("Predecessors: ")
<< "{" << join_ids(analysis.DirectControlPredecessors(op)) << "}";
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h
index f7a73dc1561..d46b81156f9 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h
@@ -46,6 +46,9 @@ CreateRemoveVariablesInSessionInitializerPass();
std::unique_ptr> CreateLiftVariablesPass(
::tensorflow::Session* session);
+// Creates a pass that removes duplicate 'tf_saved_model.bound_input' bindings.
+std::unique_ptr> CreateDedupBoundInputBindingPass(); + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index 9abf67b62a9..162ecd77d4f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -344,8 +344,9 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { for (auto& pos_and_input : llvm::enumerate(replicated_input_ops)) { auto input = pos_and_input.value(); bool is_packed = llvm::cast(input).is_packed(); + const int num_operands = input->getNumOperands(); int num_inputs = is_packed ? 1 : num_replicas; - if (input->getNumOperands() != num_inputs) + if (num_operands != num_inputs) return input->emitOpError() << "requires " << num_inputs << " operands"; auto tpu_replicated_input = llvm::cast(input); @@ -393,7 +394,8 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) { << "requires output of " << cluster.getOperationName() << " to lead to a 'tf.TPUReplicatedOutput' op"; - if (def->getNumResults() != num_replicas) + const int def_NumResults = def->getNumResults(); + if (def_NumResults != num_replicas) return def->emitOpError() << "requires " << num_replicas << " results"; auto replicate_outputs = llvm::make_range( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index e2f4fca1219..41362465cd9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" @@ -77,24 +78,28 @@ constexpr char kFuncDeviceAttr[] = "tf.device"; // because tf.TPUCopyWithLayout accepts a host input and produces a device // output. struct TPUDynamicLayoutPass - : public PassWrapper { - void runOnFunction() override; + : public TF::PerFunctionAggregateAnalysisConsumerPass< + TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> { + void runOnFunction( + FuncOp func, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis); }; // Checks if the input producer op is supported in this transform. Right now, we // only check if it is a tf.IteratorGetNext where resource input is coming from // a VarHandle on CPU or a function argument assigned to CPU. 
-bool IsSupportedInputOp(Operation* op, - TF::ResourceAliasAnalysis* resource_alias_analysis) { +bool IsSupportedInputOp( + Operation* op, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { TF::IteratorGetNextOp iterator_op = llvm::dyn_cast(op); if (!iterator_op) return false; Value resource_iterator = iterator_op.iterator(); - if (resource_alias_analysis->IsUnknownResource(resource_iterator)) + if (resource_alias_analysis.IsUnknownResource(resource_iterator)) return false; llvm::SmallSetVector aliases = - resource_alias_analysis->GetResourceAliases(resource_iterator); + resource_alias_analysis.GetResourceAliases(resource_iterator); auto is_generator = [](Value val) { if (val.isa()) return true; @@ -154,8 +159,7 @@ TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch, Value input, OpBuilder* builder) { return builder->create( execute_launch.getLoc(), llvm::ArrayRef{input.getType()}, - llvm::ArrayRef{input, get_layout.layout()}, - llvm::ArrayRef{}); + llvm::ArrayRef{input, get_layout.layout()}); } // Performs transformation for a non-replicated input. @@ -178,7 +182,7 @@ bool HandleReplicatedInputs( const int64_t execute_arg_index, Value compilation_key, tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch, const int64_t replicate_arg_index, tf_device::ReplicateOp replicate, - TF::ResourceAliasAnalysis* resource_alias_analysis) { + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { // We need to know the devices to copy to. if (!replicate.devices()) return false; int64_t num_replicas = replicate.n().getZExtValue(); @@ -216,7 +220,7 @@ bool HandleReplicatedInputs( void HandleCompileAndExecutes( tf_device::LaunchOp compile_launch, llvm::MutableArrayRef execute_launches, - TF::ResourceAliasAnalysis* resource_alias_analysis) { + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { auto compile = llvm::cast(compile_launch.GetBody().front()); tensorflow::tpu::TPUCompileMetadataProto metadata; @@ -274,9 +278,10 @@ void HandleCompileAndExecutes( compile.getContext())); } -void TPUDynamicLayoutPass::runOnFunction() { - TF::ResourceAliasAnalysis resource_alias_analysis(getFunction()); - getFunction().walk([&](TF::_TPUCompileMlirOp compile) { +void TPUDynamicLayoutPass::runOnFunction( + FuncOp func, + const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) { + func.walk([&](TF::_TPUCompileMlirOp compile) { // Detect tf._TPUCompileMlir -> tf.TPUExecute(s). auto compile_launch = llvm::dyn_cast(compile.getParentOp()); @@ -296,13 +301,13 @@ void TPUDynamicLayoutPass::runOnFunction() { } HandleCompileAndExecutes(compile_launch, execute_launches, - &resource_alias_analysis); + resource_alias_analysis); }); } } // namespace -std::unique_ptr> CreateTPUDynamicLayoutPass() { +std::unique_ptr> CreateTPUDynamicLayoutPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index af0675197ac..9365807663a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -17,11 +17,21 @@ limitations under the License. 
#include #include +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -34,10 +44,7 @@ namespace TFTPU { namespace { -constexpr char kAncestorsAttr[] = "ancestors"; constexpr char kDeviceAttr[] = "device"; -constexpr char kKeyAttr[] = "key"; -constexpr char kShapesAttr[] = "shapes"; constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; // Mapping for `_xla_outside_compilation` attribute to ops of a cluster. @@ -80,31 +87,203 @@ struct TPUExtractOutsideCompilation void runOnOperation() override; }; -// Collects and clusters ops in `block` with the same `_xla_outside_compilation` -// attribute into `clusters` This returns an error if a -// `_xla_outside_compilation` attribute of an op is empty. -LogicalResult CollectAndGroupOutsideClusterOps(Block* block, - OutsideClusterMap* clusters) { - for (Operation& op : *block) { - if (auto attr = op.getAttrOfType(kXlaOutsideCompilationAttr)) { - if (attr.getValue().empty()) - return op.emitError() - << "attribute '" << kXlaOutsideCompilationAttr << "' is empty"; +// Holds information about control flow operations that wrap outside compiled +// op. Currently only tf.If op is supported. +class ControlFlowStackInfo { + public: + enum ControlFlowBranchType { kIfThen, kIfElse }; - auto it = clusters->try_emplace(attr.getValue()); - it.first->getSecond().push_back(&op); + explicit ControlFlowStackInfo(Operation* wrapping_op, Operation* nested_op) + : callsite_op_(wrapping_op) { + // Only tf.IfRegion op is supported for now. + auto control_flow_op = llvm::cast(callsite_op_); + assert(control_flow_op); + + auto parent_region = nested_op->getParentRegion(); + if (&control_flow_op.then_branch() == parent_region) { + type_ = ControlFlowBranchType::kIfThen; + } else { + type_ = ControlFlowBranchType::kIfElse; } } - return success(); + Value GetIfPredicateValue() { + auto if_op = llvm::cast(callsite_op_); + return if_op.cond(); + } + + ControlFlowBranchType GetBranchType() const { return type_; } + + Operation* GetCallSiteOp() const { return callsite_op_; } + + private: + ControlFlowBranchType type_; + + // `this` does not hold ownership of `callsite_op_`. + Operation* callsite_op_; +}; + +// Returns a list of ControlFlowStackInfo that represents a stack of control +// flow operations that wraps `op`. 
+llvm::SmallVector GetControlFlowStackForOp( + tf_device::ClusterOp tpu_cluster, Operation* op) { + assert(tpu_cluster.getOperation()->isProperAncestor(op)); + + llvm::SmallVector controlflow_stack; + Operation* op_in_stack = op; + while (op_in_stack != tpu_cluster.getOperation()) { + auto parent_op = op_in_stack->getParentOp(); + if (llvm::isa(parent_op)) { + controlflow_stack.insert(controlflow_stack.begin(), + ControlFlowStackInfo(parent_op, op_in_stack)); + } + op_in_stack = parent_op; + } + + return controlflow_stack; } -// Moves `cluster_ops` to associated `launch_op` body. -void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op, - llvm::ArrayRef cluster_ops) { - MLIRContext* context = launch_op.getContext(); - Operation* terminator = launch_op.GetBody().getTerminator(); +// Creates a IfRegionOp with `predicate` and then/else region with yield op and +// an empty block. +TF::IfRegionOp CloneEmptyIfWithPredicate(Value predicate, bool is_stateless, + Location loc, OpBuilder* builder) { + auto host_side_if = builder->create( + loc, llvm::SmallVector{}, predicate, is_stateless); + // Create empty then branch region. + auto& then_branch = host_side_if.then_branch(); + builder->setInsertionPoint(&then_branch.front(), then_branch.front().begin()); + builder->createBlock(&then_branch); + builder->create(loc, llvm::SmallVector({})); + + // Create empty else branch region. + auto& else_branch = host_side_if.else_branch(); + builder->setInsertionPoint(&else_branch.front(), else_branch.front().begin()); + builder->createBlock(&else_branch); + builder->create(loc, llvm::SmallVector({})); + return host_side_if; +} + +// Replicates tf.IfRegion op to host side computation. +Operation* ReplicateIf(const ControlFlowStackInfo& controlflow_info, + llvm::StringRef outside_cluster_name, ModuleOp module, + Value compilation_key, OpBuilder* builder, + int* send_recv_counter) { + // Create XlaSendToHostOp to send predicate value from device to host. + OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint(); + auto if_callsite_op = + llvm::cast(controlflow_info.GetCallSiteOp()); + builder->setInsertionPoint(if_callsite_op); + + const auto predicate_send_recv_key = + llvm::formatv("if_predicate_channel_{0}_{1}", outside_cluster_name, + *send_recv_counter) + .str(); + *send_recv_counter += 1; + + auto predicate = if_callsite_op.cond(); + auto predicate_shape = predicate.getType(); + builder->create(if_callsite_op.getLoc(), predicate, + predicate_send_recv_key); + + // Create XlaRecvAtHostOp to receive predicate value from host. + builder->restoreInsertionPoint(insert_point); + auto recv_predicate_at_host = builder->create( + if_callsite_op.getLoc(), llvm::ArrayRef{predicate_shape}, + /*dynamic_key=*/compilation_key, + builder->getStringAttr(predicate_send_recv_key), + /*device_ordinal=*/builder->getI64IntegerAttr(0)); + + // Create host side if op. + return CloneEmptyIfWithPredicate(recv_predicate_at_host.getResult(0), + if_callsite_op.is_stateless(), + if_callsite_op.getLoc(), builder); +} + +// TODO(b/157054714): Use a better abstraction instead of +// _TPUCompileMlirOp and _XlaRecvAtHostOp and _XlaSendFromHostOp. +// Creates a compilation key as placeholder. A placeholder compilation cache key +// is created because it is a required input to _XlaRecvAtHost and +// _XlaSendFromHost but the _TPUCompileMlir has not yet been created for the TPU +// cluster that contains the outside compiled ops. 
This placeholder should be +// replaced by the TPU cluster _TPUCompileMlir in a subsequent pass. +Value CreateCompilationKeyPlaceholder(Location loc, OpBuilder* builder) { + auto result_type = + RankedTensorType::get({2}, builder->getType()); + return builder->create( + loc, /*program=*/result_type, llvm::ArrayRef{}); +} + +// Replicates the control flow operations that wraps outside compiled ops to +// `destination_block`. +Block* ReplicateControlFlowStack( + llvm::StringRef outside_cluster_name, + const llvm::SmallVectorImpl& stack_info, + tf_device::ClusterOp tpu_cluster, ModuleOp module, Value compilation_key, + Block* destination_block, int* send_recv_counter) { + assert(stack_info.size()); + OpBuilder builder = OpBuilder::atBlockTerminator(destination_block); + Operation* previous_replicated_controlflow_op = nullptr; + for (const auto& controlflow_stack_info : stack_info) { + // Create control flow op given provided insertion point and + // ControlFlowStackInfo. + previous_replicated_controlflow_op = + ReplicateIf(controlflow_stack_info, outside_cluster_name, module, + compilation_key, &builder, send_recv_counter); + auto if_op = llvm::cast(previous_replicated_controlflow_op); + auto type = controlflow_stack_info.GetBranchType(); + + // Update the insertion point to proper region inside the newly created + // control flow op. + if (type == ControlFlowStackInfo::kIfThen) { + builder.setInsertionPoint(&if_op.then_branch().front().front()); + } else { + builder.setInsertionPoint(&if_op.else_branch().front().front()); + } + } + + // Return the inner most branch at which outside compiled op is located. + // This block will later be used as insertion point to create send/recv ops. + auto inner_most_controlflow_stack = stack_info.back(); + auto inner_most_if = + llvm::cast(previous_replicated_controlflow_op); + if (inner_most_controlflow_stack.GetBranchType() == + ControlFlowStackInfo::kIfThen) { + return &inner_most_if.then_branch().front(); + } else { + return &inner_most_if.else_branch().front(); + } +} + +// Collects and clusters ops in `block` with the same `_xla_outside_compilation` +// attribute into `clusters` This returns an error if a +// `_xla_outside_compilation` attribute of an op is empty. +// TODO(b/163141763): Make sure ops inside control flow regions are not outside +// compiled if the entire control flow op is marked as outside compiled. +LogicalResult CollectAndGroupOutsideClusterOps(Block* block, + OutsideClusterMap* clusters) { + auto walk_result = block->walk([&](Operation* op) { + if (auto attr = op->getAttrOfType(kXlaOutsideCompilationAttr)) { + if (attr.getValue().empty()) { + op->emitError() << "attribute '" << kXlaOutsideCompilationAttr + << "' is empty"; + return WalkResult::interrupt(); + } + + auto it = clusters->try_emplace(attr.getValue()); + it.first->getSecond().push_back(op); + } + return WalkResult::advance(); + }); + + return failure(walk_result.wasInterrupted()); +} + +// Moves `cluster_ops` to associated `block`. +void MoveOutsideClusterOpsToBlock(Block& block, + llvm::ArrayRef cluster_ops, + MLIRContext* context) { + Operation* terminator = block.getTerminator(); for (Operation* cluster_op : cluster_ops) { // Remove `_xla_outside_compilation` and `device` attribute from ops in the // cluster as that information will be present in the `launch_op`. @@ -115,7 +294,7 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op, } } -// Creates a `tf_device::LaunchOp` to wrap cluster ops. +// Creates a `tf_device.launch` to wrap cluster ops. 
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
OpBuilder* builder, Operation* last_cluster_op,
llvm::StringRef host_device) {
@@ -196,78 +375,91 @@ void SetHostComputeInsertion(
// Creates the HostCompute with `inputs` and `outputs`
// using `communication_key`.
-TF::_HostComputeMlirOp CreateHostCompute(
+TF::_XlaHostComputeMlirOp CreateHostCompute(
OpBuilder* builder, tf_device::ClusterOp tpu_cluster,
llvm::ArrayRef cluster_ops,
const llvm::SmallSetVector& inputs, llvm::ArrayRef outputs,
- llvm::StringRef communication_key) {
+ llvm::StringRef args_communication_key,
+ llvm::StringRef retvals_communication_key) {
llvm::SmallVector device_output_types;
for (const auto& output : outputs)
device_output_types.push_back(output.getType());
SetHostComputeInsertion(builder, cluster_ops, inputs);
- auto host_compute = builder->create(
+ auto host_compute = builder->create(
tpu_cluster.getLoc(), device_output_types, inputs.getArrayRef(),
- llvm::ArrayRef{});
- host_compute.setAttr(kAncestorsAttr, builder->getArrayAttr({}));
- host_compute.setAttr(kShapesAttr, builder->getArrayAttr({}));
- host_compute.setAttr(kKeyAttr, builder->getStringAttr(communication_key));
+ builder->getStringAttr(args_communication_key),
+ builder->getStringAttr(retvals_communication_key),
+ /*tpu_core=*/builder->getI64IntegerAttr(0));
return host_compute;
}
void MoveOutsideCompiledOps(
- tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name,
- tf_device::LaunchOp host_launch_op, llvm::ArrayRef cluster_ops,
+ ModuleOp module, tf_device::ClusterOp tpu_cluster,
+ llvm::StringRef outside_cluster_name, tf_device::LaunchOp host_launch_op,
+ llvm::ArrayRef cluster_ops,
const llvm::SmallSetVector& external_inputs,
llvm::ArrayRef external_outputs) {
+ // Since ops in `cluster_ops` do not cross function/control flow boundary, it
+ // is sufficient to identify the control flow that wraps `cluster_ops` by
+ // looking at any arbitrary op inside `cluster_ops`.
+ auto controlflow_stack =
+ GetControlFlowStackForOp(tpu_cluster, cluster_ops.front());
+
+ Value compilation_key;
+ if (!controlflow_stack.empty() || !external_inputs.empty() ||
+ !external_outputs.empty()) {
+ OpBuilder builder(&host_launch_op.GetBody().front());
+ compilation_key =
+ CreateCompilationKeyPlaceholder(tpu_cluster.getLoc(), &builder);
+ }
+
+ Block* block_to_move_host_cluster = nullptr;
+ if (controlflow_stack.empty()) {
+ block_to_move_host_cluster = &host_launch_op.GetBody();
+ } else {
+ int send_recv_counter = 0;
+ block_to_move_host_cluster = ReplicateControlFlowStack(
+ outside_cluster_name, controlflow_stack, tpu_cluster, module,
+ compilation_key, &host_launch_op.GetBody(), &send_recv_counter);
+ }
+
+ MLIRContext* context = host_launch_op.getContext();
if (external_inputs.empty() && external_outputs.empty()) {
- MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops);
+ MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops,
+ context);
return;
}
- OpBuilder builder(host_launch_op.GetBody().getTerminator());
- auto result_type =
- RankedTensorType::get({}, builder.getType());
-
- std::string txt_metadata;
- std::string txt_module;
- // TODO(b/157054714): Use a better abstraction instead of _TPUCompileMlirOp
- // and _XlaRecvAtHostOp and _XlaSendFromHostOp.
-
- // A placeholder _TpuCompileMlirOp is created because it is required input to
- // XlaRecvAtHostOp and XlaSendFromHostOp but the _TpuCompileMlirOp has not yet
- // been created for the TPU cluster that contains the outside compiled ops.
- // This placeholder should be replaced by the TPU cluster _TPUCompileMlirOp in - // a subsequent pass. - auto compile_op = builder.create( - tpu_cluster.getLoc(), /*compilation_status=*/result_type, /*program=*/ - llvm::ArrayRef{result_type}, llvm::ArrayRef{}, txt_module, - txt_metadata); - + OpBuilder builder(block_to_move_host_cluster->getTerminator()); llvm::SmallVector host_output_types; for (const auto& external_input : external_inputs) host_output_types.push_back(external_input.getType()); - std::string communication_key = - llvm::formatv("host_compute_channel_{0}", outside_cluster_name).str(); - // XlaRecvAtHostOp takes both the program key(dynamic_key) from the - // _TpuCompileMlirOp and the communication_key. + std::string args_communication_key = + llvm::formatv("host_compute_channel_{0}_args", outside_cluster_name) + .str(); + std::string retvals_communication_key = + llvm::formatv("host_compute_channel_{0}_retvals", outside_cluster_name) + .str(); + auto recv_at_host = builder.create( tpu_cluster.getLoc(), host_output_types, - /*dynamic_key=*/compile_op.getResult(1), - builder.getStringAttr(communication_key), - builder.getIntegerAttr(builder.getIntegerType(64), 0)); + /*dynamic_key=*/compilation_key, + builder.getStringAttr(args_communication_key), + /*device_ordinal=*/builder.getI64IntegerAttr(0)); - auto host_compute = - CreateHostCompute(&builder, tpu_cluster, cluster_ops, external_inputs, - external_outputs, communication_key); - MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops); + auto host_compute = CreateHostCompute( + &builder, tpu_cluster, cluster_ops, external_inputs, external_outputs, + args_communication_key, retvals_communication_key); + MoveOutsideClusterOpsToBlock(*block_to_move_host_cluster, cluster_ops, + context); - builder.setInsertionPoint(host_launch_op.GetBody().getTerminator()); + builder.setInsertionPoint(block_to_move_host_cluster->getTerminator()); builder.create( tpu_cluster.getLoc(), external_outputs, - /*dynamic_key=*/compile_op.getResult(1), - builder.getStringAttr(communication_key), - /*device_ordinal=*/builder.getIntegerAttr(builder.getIntegerType(64), 0)); + /*dynamic_key=*/compilation_key, + builder.getStringAttr(retvals_communication_key), + /*device_ordinal=*/builder.getI64IntegerAttr(0)); for (auto result : llvm::zip(external_inputs, recv_at_host.getResults())) mlir::replaceAllUsesInRegionWith(std::get<0>(result), std::get<1>(result), @@ -280,7 +472,8 @@ void MoveOutsideCompiledOps( // Creates a `parallel_execute` op in place of launch with 'clusters` and // 'launch` as regions. 
-void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster, +void CreateParallelExecuteFromOutsideClusters(ModuleOp module, + tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters, llvm::StringRef host_device) { OpBuilder builder(tpu_cluster); @@ -296,6 +489,7 @@ void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster, Block& outside_block = parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); + builder.setInsertionPointToEnd(&outside_block); tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster( &builder, cluster_ops.back(), host_device); @@ -304,10 +498,9 @@ void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster, auto external_inputs = GetExternalOperands(cluster_ops); auto external_outputs = GetExternalOutputs(cluster_ops); - MoveOutsideCompiledOps(tpu_cluster, cluster.value().getFirst(), + MoveOutsideCompiledOps(module, tpu_cluster, cluster.value().getFirst(), host_launch_op, cluster_ops, external_inputs, external_outputs); - builder.setInsertionPointToEnd(&outside_block); builder.create(tpu_cluster.getLoc(), ArrayRef{}); @@ -353,7 +546,8 @@ void TPUExtractOutsideCompilation::runOnOperation() { std::string host_device; tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster, &host_device); - CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters, + + CreateParallelExecuteFromOutsideClusters(module, tpu_cluster, clusters, host_device); return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc index 3fd0dcd5a67..52c9287b619 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_merge_variables_with_execute.cc @@ -298,7 +298,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // Populate infos.old_to_new_output_mapping. int new_output_index = 0; infos.old_to_new_output_mapping.resize(execute_launch.getNumResults()); - for (int i = 0; i < execute_launch.getNumResults(); ++i) { + for (int i = 0, end = execute_launch.getNumResults(); i < end; ++i) { if (output_fused[i]) { infos.old_to_new_output_mapping[i] = -1; } else { @@ -375,7 +375,7 @@ void ReplaceParallelExecute(tf_device::ParallelExecuteOp parallel_execute, // Replace the uses of the original parallel_execute for the region containing // the merged execute. auto old_region_results = parallel_execute.GetRegionOutputs(region_index); - for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { + for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) { if (infos.old_to_new_output_mapping[i] < 0) continue; old_region_results[i].replaceAllUsesWith(new_parallel_execute_op->getResult( infos.old_to_new_output_mapping[i] + num_results_before_region)); @@ -407,7 +407,7 @@ void ReplaceExecute(tf_device::LaunchOp execute_launch, tf_device::LaunchOp merged_execute_launch, const VariableAccessesForTPUExecute& infos) { // Replace the uses. 
- for (int i = 0; i < infos.old_to_new_output_mapping.size(); ++i) { + for (int i = 0, end = infos.old_to_new_output_mapping.size(); i < end; ++i) { if (infos.old_to_new_output_mapping[i] < 0) continue; execute_launch.getResult(i).replaceAllUsesWith( merged_execute_launch.getResult(infos.old_to_new_output_mapping[i])); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 050ba24417f..ca77feafc05 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -473,9 +473,8 @@ LogicalResult BuildExecuteOp( if (failed(result)) return failure(); // TPUExecute has same output types as cluster_func. - *execute_op = builder->create( - cluster_func.getLoc(), output_types, inputs, - llvm::ArrayRef{}); + *execute_op = builder->create(cluster_func.getLoc(), + output_types, inputs); return success(); } @@ -644,10 +643,7 @@ LogicalResult Rewrite( // Collect `num_replicas` and `num_cores_per_replica` attributes. int num_replicas = 1; tf_device::ReplicateOp replicate = - cluster_func.getParentOp() - ? llvm::dyn_cast_or_null( - cluster_func.getParentOp()) - : nullptr; + cluster_func.getParentOfType(); if (replicate) num_replicas = replicate.n().getLimitedValue(); auto num_cores_per_replica_attr = cluster_func.getAttrOfType( @@ -716,9 +712,9 @@ LogicalResult Rewrite( // structured lowering. if (auto parallel_op = llvm::dyn_cast( cluster_func.getParentOp())) { - parallel_op.walk([&](TF::_TPUCompileMlirOp parallel_compile_op) { - parallel_compile_op.replaceAllUsesWith(compile_op); - parallel_compile_op.erase(); + parallel_op.walk([&](TF::_TPUCompileMlirPlaceholderProgramKeyOp key_op) { + key_op.replaceAllUsesWith(compile_op->getResult(1)); + key_op.erase(); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index 7befa68f3d8..204a674e632 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -604,8 +604,7 @@ void TPUSpaceToDepthPass::runOnOperation() { } // Get the function on device. - auto device_func = - getOperation().lookupSymbol(cluster_func->getFunc()); + auto device_func = cluster_func->getFunc(); if (!device_func) return; TF::Conv2DOp first_conv; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc index f3588c8359b..6cd9f763b87 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc @@ -13,24 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace mlir { namespace TFTPU { namespace { +constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation"; constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer"; struct TPUUpdateEmbeddingEnqueueOpInputs @@ -44,8 +49,7 @@ struct TPUUpdateEmbeddingEnqueueOpInputs LogicalResult ExtractEmbeddingAttribute( Operation* op, llvm::StringMap* embedding_op_map) { auto embedding_attr = op->getAttrOfType(kTPUEmbeddingAttr); - if (!embedding_attr) - return op->emitOpError("requires attribute '_tpu_embedding_layer'"); + if (!embedding_attr) return mlir::success(); if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second) return op->emitOpError( @@ -87,7 +91,8 @@ LogicalResult FindTPUEmbeddingOps( LogicalResult UpdateEmbeddingEnqueueOpInput( const llvm::StringMap& enqueue_op_map, const llvm::StringMap& recv_activation_op_map, - const llvm::StringMap& send_gradient_op_map) { + const llvm::StringMap& send_gradient_op_map, + OpBuilder* builder) { for (const auto& it : enqueue_op_map) { const auto& embedding_attr = it.getKey(); Operation* embedding_op = it.second; @@ -97,21 +102,36 @@ LogicalResult UpdateEmbeddingEnqueueOpInput( << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op"; // TPU Embedding enqueue ops take different inputs depending on whether - // graph is in training mode or in eval/prediction mode. The inputs to the - // enqueue ops are present/listed as operands to SelectV2 op. Then branch - // operand of the SelectV2 op represents input to take during training - // and else branch operand represents input to take during - // prediction/evaluation. If SendTPUEmbeddingGradients op exists in the - // graph, then graph is in training mode, so correctly forward the input - // of SelectV2 op as operand to the TPU embedding enqueue op. + // graph is in training mode or in eval/prediction mode. During training, + // the mode parameter for TPUEmbeddingEnqueue op must be `train` and for + // evaluation or prediction, mode must be set to `inference`. + // If SendTPUEmbeddingGradients op exists in the graph, then graph is + // in training mode, so create a const op with value `train` use the + // output value of the constant as an operand to the TPU embedding + // enqueue op. bool is_training = send_gradient_op_map.count(embedding_attr); - for (auto enqueue_operand : embedding_op->getOperands()) { - if (auto select = llvm::dyn_cast_or_null( - enqueue_operand.getDefiningOp())) { - enqueue_operand.replaceAllUsesWith(is_training ? 
select.t() - : select.e()); - } - } + + // The last operand of TPUEmbeddingEnqueue ops is the mode which + // represents whether graph is in training mode or in evaluation mode. + auto& mode_enqueue_operand = + embedding_op->getOpOperand(embedding_op->getNumOperands() - 1); + + llvm::SmallVector mode_string_value; + mode_string_value.emplace_back(is_training ? "train" : "inference"); + builder->setInsertionPoint(embedding_op); + auto enqueue_mode = builder->create( + embedding_op->getLoc(), + DenseStringElementsAttr::get( + RankedTensorType::get({}, builder->getType()), + mode_string_value)); + + auto outside_compilation_attr = + embedding_op->getAttrOfType(kXlaOutsideCompilationAttr); + if (outside_compilation_attr) + enqueue_mode.setAttr(kXlaOutsideCompilationAttr, + outside_compilation_attr); + + mode_enqueue_operand.set(enqueue_mode); } return success(); @@ -141,8 +161,9 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() { return signalPassFailure(); } - if (failed(UpdateEmbeddingEnqueueOpInput( - enqueue_op_map, recv_activation_op_map, send_gradient_op_map))) + if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map, + recv_activation_op_map, + send_gradient_op_map, &builder))) return signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index 5bc6bd4e053..3262b83fc94 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -351,7 +351,7 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body, cond.setType(FunctionType::get(append_types(cond.getType().getInputs()), cond.getType().getResults(), cond.getContext())); - for (int64_t i = 0; i < state_vars.size(); ++i) { + for (int64_t i = 0, end = state_vars.size(); i < end; ++i) { int64_t arg_index = body.getNumArguments() - state_vars.size() + i; TF::VarHandleOp state_var = state_vars[i]; auto device_attr = state_var.getAttr(kDeviceAttr); @@ -365,16 +365,6 @@ TF::WhileOp AddStateVarsToWhileOp(TF::WhileOp while_op, FuncOp body, while_op.getLoc(), append_types(llvm::to_vector<4>(while_op.getResultTypes())), new_while_operands, while_op.getAttrs()); - if (new_while_op.output_shapes().size() != 0) { - auto new_output_shapes = llvm::to_vector<4>(new_while_op.output_shapes()); - // VarHandleOp is a scalar shape resource. - for (int64_t i = 0; i < state_vars.size(); ++i) { - new_output_shapes.push_back( - mlir::TF::ShapeAttr::get(builder.getContext(), ArrayRef())); - } - new_while_op.setAttr("output_shapes", - builder.getArrayAttr(new_output_shapes)); - } while_op.replaceAllUsesWith( new_while_op.getResults().take_front(while_op.getNumResults())); while_op.erase(); @@ -462,9 +452,8 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, !llvm::isa(compile_launch.GetBody().front())) return; - auto module = while_op.getParentOfType(); - auto body = llvm::cast(module.lookupSymbol(while_op.body())); - auto cond = llvm::cast(module.lookupSymbol(while_op.cond())); + FuncOp body = while_op.body_func(); + FuncOp cond = while_op.cond_func(); // Analyze the formattable inputs. 
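
In the TPU embedding enqueue rewrite above, the angle-bracketed template arguments of the `create` and `getType` calls are not visible, so the snippet reads as if they had none. A hedged reconstruction, assuming `TF::ConstOp`, `TF::StringType`, and `llvm::SmallVector<llvm::StringRef, 1>` as the elided parameters (inferred from the surrounding TF dialect code, not spelled out in this diff):

    // Build a scalar string constant holding "train" or "inference" and feed it
    // to the trailing `mode` operand of the TPU embedding enqueue op.
    llvm::SmallVector<llvm::StringRef, 1> mode_string_value;
    mode_string_value.emplace_back(is_training ? "train" : "inference");
    builder->setInsertionPoint(embedding_op);
    auto enqueue_mode = builder->create<TF::ConstOp>(
        embedding_op->getLoc(),
        mlir::DenseStringElementsAttr::get(
            mlir::RankedTensorType::get({}, builder->getType<TF::StringType>()),
            mode_string_value));

With the mode carried by its own constant operand, the pass no longer has to pattern-match a `SelectV2` feeding the enqueue op to decide between the training and inference inputs.
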
auto execute_arg_to_outer_args = @@ -521,8 +510,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, replicate.GetNumReplicatedBlockArguments() - 1)); builder.setInsertionPoint(execute_launch); auto reformat_op = builder.create( - execute_launch.getLoc(), llvm::ArrayRef{}, reformat_operands, - llvm::ArrayRef{}); + execute_launch.getLoc(), llvm::ArrayRef{}, reformat_operands); WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op, execute_launch.device()); @@ -579,8 +567,7 @@ void HandleReplicateOp(TF::WhileOp while_op, tf_device::ReplicateOp replicate, default_state_key.getResult()); // Unformat op. auto unformat_op = builder.create( - while_op.getLoc(), llvm::ArrayRef{}, unformat_operands, - llvm::ArrayRef{}); + while_op.getLoc(), llvm::ArrayRef{}, unformat_operands); WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op, execute_launch.device()); builder.create(while_op.getLoc(), ArrayRef{}); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc index f09cf7b093e..0a69987deb0 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/breakup-islands.cc @@ -41,25 +41,28 @@ namespace mlir { namespace { -struct BreakUpIslands : PassWrapper { - void runOnFunction() final; +class BreakUpIslands : public TF::PerFunctionAggregateAnalysisConsumerPass< + BreakUpIslands, TF::SideEffectAnalysis> { + public: + void runOnFunction(FuncOp func, + const TF::SideEffectAnalysis::Info& side_effect_analysis); void BreakUpIsland(tf_executor::IslandOp island_op, - const TF::SideEffectAnalysis& side_effect_analysis, + const TF::SideEffectAnalysis::Info& side_effect_analysis, llvm::DenseMap>* new_control_inputs); }; -void BreakUpIslands::runOnFunction() { - auto graph_op_range = getFunction().getBody().front().without_terminator(); +void BreakUpIslands::runOnFunction( + FuncOp func, const TF::SideEffectAnalysis::Info& side_effect_analysis) { + auto graph_op_range = func.front().without_terminator(); tf_executor::GraphOp graph_op; - if (graph_op_range.begin() != graph_op_range.end() && - std::next(graph_op_range.begin()) == graph_op_range.end()) { - graph_op = dyn_cast( - getOperation().getBody().front().front()); - } + + if (llvm::hasSingleElement(graph_op_range)) + graph_op = dyn_cast(func.front().front()); + if (!graph_op) { - getOperation().emitError("expected function to contain only a graph_op"); + func.emitError("expected function to contain only a graph_op"); signalPassFailure(); return; } @@ -67,7 +70,6 @@ void BreakUpIslands::runOnFunction() { // New control inputs to be added. For an operation x, new_control_inputs[x] // contains all control inputs that need to be added to x as operands. llvm::DenseMap> new_control_inputs; - auto& side_effect_analysis = getAnalysis(); // Iterate in reverse order to avoid invalidating Operation* stored in // new_control_inputs. for (auto& item : @@ -76,7 +78,7 @@ void BreakUpIslands::runOnFunction() { BreakUpIsland(island, side_effect_analysis, &new_control_inputs); } } - OpBuilder builder(getOperation()); + OpBuilder builder(func); // For every op, add new control inputs in reverse order so that the ops don't // get invalidated. @@ -181,7 +183,7 @@ struct IslandSourcesAndSinks { // Finds IslandSourcesAndSinks for an unmodified island. 
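
The `llvm::hasSingleElement(graph_op_range)` check introduced in `BreakUpIslands::runOnFunction` above is a drop-in replacement for the explicit begin/next/end comparison it removes. A small equivalence sketch (the range name is only illustrative):

    #include <iterator>
    #include "llvm/ADT/STLExtras.h"

    template <typename RangeT>
    bool HasExactlyOneOp(RangeT&& graph_op_range) {
      // Equivalent to:
      //   graph_op_range.begin() != graph_op_range.end() &&
      //   std::next(graph_op_range.begin()) == graph_op_range.end()
      return llvm::hasSingleElement(graph_op_range);
    }
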
IslandSourcesAndSinks FindSourcesAndSinksInIsland( tf_executor::IslandOp island, - const TF::SideEffectAnalysis& side_effect_analysis) { + const TF::SideEffectAnalysis::Info& side_effect_analysis) { IslandSourcesAndSinks result; auto island_body = island.GetBody().without_terminator(); for (Operation& sub_op : island_body) { @@ -208,7 +210,7 @@ IslandSourcesAndSinks FindSourcesAndSinksInIsland( // are chained together by control flow values. void BreakUpIslands::BreakUpIsland( tf_executor::IslandOp island_op, - const TF::SideEffectAnalysis& side_effect_analysis, + const TF::SideEffectAnalysis::Info& side_effect_analysis, llvm::DenseMap>* new_control_inputs) { auto island_body = island_op.GetBody().without_terminator(); @@ -323,7 +325,7 @@ void BreakUpIslands::BreakUpIsland( } // namespace -std::unique_ptr> CreateBreakUpIslandsPass() { +std::unique_ptr> CreateBreakUpIslandsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 7983dfe0065..571d5e3e715 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -511,17 +511,19 @@ StatusOr> Exporter::Convert( // generate unique names. if (!output_names.empty()) { const int num_data_results = graph_op.getNumResults(); - TF_RET_CHECK(output_names.size() == num_data_results) + const int64 output_names_size = output_names.size(); + TF_RET_CHECK(output_names_size == num_data_results) << "output names (" << output_names.size() << ") != terminator operands (" << num_data_results << ")"; llvm::DenseMap output_op_to_name; llvm::StringMap name_to_op; for (const auto& it : llvm::enumerate(graph_op.GetFetch().getOperands())) { // Skip control rets. - if (it.index() >= num_data_results) break; + const int64 index = it.index(); + if (index >= num_data_results) break; // TODO(jpienaar): If there is a result index specified, ensure only one // and that it matches the result index of the op. 
- std::string orig_name(output_names[it.index()]); + std::string orig_name(output_names[index]); auto tensor_id = ParseTensorName(orig_name); auto name = LegalizeNodeName( llvm::StringRef(tensor_id.node().data(), tensor_id.node().size())); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index c7d5339f93c..94ddf76736e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -119,7 +119,6 @@ static inline absl::string_view StringRefToView(llvm::StringRef ref) { namespace tensorflow { using mlir::NamedAttrList; using mlir::TensorType; -using mlir::TF::VarHandleOp; using mlir::tf_saved_model::AssetOp; using mlir::tf_saved_model::GlobalTensorOp; using mlir::tf_saved_model::SessionInitializerOp; @@ -129,12 +128,6 @@ namespace { constexpr char kTpuReplicateAttr[] = "_tpu_replicate"; -bool IsDisableCallShapeInferenceAttribute(const AttrValue& attr_value, - llvm::StringRef attr_name) { - return attr_name.compare("_disable_call_shape_inference") == 0 && - attr_value.value_case() == AttrValue::kB; -} - bool IsOutputShapesAttribute(const AttrValue& attr_value, llvm::StringRef attr_name) { return attr_name.compare("_output_shapes") == 0 && @@ -336,14 +329,11 @@ class ImporterBase { NamedAttrList* attributes); // Helper to create either a tf_executor operation or a TF operation wrapped - // in an island. When convert_to_legacy_call is true, converts the operation - // representing a call to a library function with a name represented in - // node_type_name to LegacyCallOp. + // in an island. mlir::Operation* CreateOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands, - bool convert_to_legacy_call = false); + const llvm::SmallVectorImpl& control_operands); // Converts one NodeDef from the input GraphDef into an Operation and // inserts it into the MLIR module using builder_. @@ -1680,8 +1670,7 @@ Status ImporterBase::EmitErrorWithLocationStr(const Node& node, mlir::Operation* ImporterBase::CreateOperation( const Node& node, llvm::StringRef node_type_name, const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands, - bool convert_to_legacy_call) { + const llvm::SmallVectorImpl& control_operands) { // For the tf.executor specific operations (not wrapped in an island), we // have an extra returned value for the control result, and we concatenate // control and non-control operands. @@ -1744,25 +1733,7 @@ mlir::Operation* ImporterBase::CreateOperation( mlir::OpBuilder::atBlockEnd(&island.GetBody()); // Create the operation inside the island now. 
- mlir::Operation* inner_op; - if (convert_to_legacy_call) { - bool disable_call_shape_inference = false; - for (const auto& name_and_value : node.attrs()) { - const auto& attr_name = name_and_value.first; - const AttrValue& attr_value = name_and_value.second; - if (IsDisableCallShapeInferenceAttribute(attr_value, attr_name)) { - disable_call_shape_inference = attr_value.b(); - } - } - - mlir::BoolAttr attribute = - builder_.getBoolAttr(disable_call_shape_inference); - inner_op = island_builder.create( - result.location, result.types, result.operands, - island_builder.getSymbolRefAttr(node_type_name), attribute); - } else { - inner_op = island_builder.createOperation(result); - } + mlir::Operation* inner_op = island_builder.createOperation(result); // Sets operand_segment_sizes or result_segment_sizes attribute to the op. const auto set_segment_sizes_attr = @@ -1927,13 +1898,6 @@ Status ImporterBase::ConvertNode(const Node& node) { // Remove _output_shapes attribute that will be added by the exporter. if (IsOutputShapesAttribute(attr_value, attr_name)) continue; - // We represent the _diable_call_shape_inference attribute and remove - // the _output_shapes attribute for LegacyCall. If a call has other - // attributes, we can't convert it to LegacyCall. - if (convert_to_legacy_call && - !IsDisableCallShapeInferenceAttribute(attr_value, attr_name)) { - convert_to_legacy_call = false; - } if (attr_value.value_case() == AttrValue::kFunc) { // Attribute iteration order is not defined for protocol buffer Map. // Process function attributes separately in the lexicographical order to @@ -1957,26 +1921,35 @@ Status ImporterBase::ConvertNode(const Node& node) { result.attributes.push_back(builder_.getNamedAttr( "device", builder_.getStringAttr(std::string(node_def.device())))); - // Map If and StatelessIf op in TensorFlow to the common If op in MLIR and add - // the differentiating attribute. - if (node.IsIfNode()) { - result.name = mlir::OperationName(get_full_op_name("If"), context_); - mlir::BoolAttr val = builder_.getBoolAttr(node_type_name == "StatelessIf"); - result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); + // Map user function calls to LegacyCall ops and add the user function name + // as an attribute. + if (convert_to_legacy_call) { + result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_); + mlir::SymbolRefAttr val = builder_.getSymbolRefAttr(node_type_name); + result.addAttribute("f", val); + + if (!result.attributes.get("_disable_call_shape_inference")) { + result.addAttribute("_disable_call_shape_inference", + builder_.getBoolAttr(false)); + } } - // Map While and StatelessWhile op in TensorFlow to the common While op in - // MLIR and add the differentiating attribute. - if (node.IsWhileNode()) { - result.name = mlir::OperationName(get_full_op_name("While"), context_); - mlir::BoolAttr val = - builder_.getBoolAttr(node_type_name == "StatelessWhile"); + auto composite_control_flow_op = [&](const std::string& name) { + result.name = mlir::OperationName(get_full_op_name(name), context_); + bool stateless = absl::StartsWith(node_type_name, "Stateless"); + mlir::BoolAttr val = builder_.getBoolAttr(stateless); result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - } + }; + + // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common + // Case/If/While op in MLIR and add the differentiating attribute. 
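
The `composite_control_flow_op` lambda above, together with the `IsCaseNode`/`IsIfNode`/`IsWhileNode` checks that immediately follow, collapses the previously separate If and While handling: the common MLIR op name comes from the node kind, and `is_stateless` is derived from the `Stateless` prefix of the TensorFlow op name. A hedged, stand-alone sketch of that mapping (the `tf.` prefix and the helper struct are illustrative assumptions, not taken from this patch):

    #include <string>
    #include "absl/strings/match.h"

    struct MappedControlFlowOp {
      std::string mlir_op_name;  // e.g. "tf.If"
      bool is_stateless;
    };

    // "StatelessIf"/"If" -> tf.If, "StatelessWhile"/"While" -> tf.While, etc.
    MappedControlFlowOp MapControlFlowNode(const std::string& node_type_name,
                                           const std::string& base_name) {
      bool stateless = absl::StartsWith(node_type_name, "Stateless");
      return {"tf." + base_name, stateless};
    }

    // MapControlFlowNode("StatelessIf", "If") -> {"tf.If", true}
    // MapControlFlowNode("Case", "Case")      -> {"tf.Case", false}
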
+ if (node.IsCaseNode()) composite_control_flow_op("Case"); + if (node.IsIfNode()) composite_control_flow_op("If"); + if (node.IsWhileNode()) composite_control_flow_op("While"); // Register the mapping between the TF node and the newly created operation. - node_values_[node.id()] = CreateOperation( - node, node_type_name, result, control_operands, convert_to_legacy_call); + node_values_[node.id()] = + CreateOperation(node, node_type_name, result, control_operands); return Status::OK(); } @@ -2387,7 +2360,8 @@ GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( "' is missing attribute 'index'"); auto index = attr->i(); - if (nodes->size() < index + 1) nodes->resize(index + 1); + const int num_nodes = nodes->size(); + if (num_nodes < index + 1) nodes->resize(index + 1); if ((*nodes)[index].node != nullptr) return errors::InvalidArgument(node->type_string(), " node '", @@ -2895,7 +2869,7 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) { mlir::OpBuilder builder(func.getBody()); llvm::SmallVector new_input_types; for (int i = 0, e = func.getNumArguments(); i < e; i++) { - auto arg = func.front().getArgument(i); + auto arg = func.getArgument(i); auto global_tensor = mlir::tf_saved_model::LookupBoundInputOfType< mlir::tf_saved_model::GlobalTensorOp>(func, i, symbol_table); if (global_tensor) { @@ -3108,7 +3082,8 @@ Status CreateSavedModelIR( TF_ASSIGN_OR_RETURN(auto input_index_paths, input_linearizer.GetLeafIndexPaths( error_context + "in input signature: ")); - if (bound_input_base != input_index_paths.size()) { + const int input_index_paths_size = input_index_paths.size(); + if (bound_input_base != input_index_paths_size) { return errors::InvalidArgument( error_context, "Argument mismatch between concrete function input signature " @@ -3389,12 +3364,13 @@ SavedModelSignatureDefImporter::ConvertAssets() { results.reserve(asset_file_defs.size()); mlir::OpBuilder builder(module_->getBodyRegion()); + unsigned i = 0; // Use to generate unique sym_name(s) for duplicate assets. for (const auto& asset : asset_file_defs) { auto asset_op = builder.create( module_->getLoc(), /*sym_name=*/ builder.getStringAttr( - absl::StrCat("__tf_saved_model_asset_", asset.filename())), + absl::StrCat("__tf_saved_model_asset", i++, "_", asset.filename())), /*filename=*/ builder.getStringAttr( io::JoinPath(kSavedModelAssetsDirectory, asset.filename()))); @@ -3590,9 +3566,9 @@ Status SavedModelSignatureDefImporter::LiftVariables() { pm.addPass(mlir::TF::CreatePromoteVarHandlesToArgsPass()); pm.addPass( mlir::tf_saved_model::CreateLiftVariablesPass(bundle_.GetSession())); + pm.addPass(mlir::tf_saved_model::CreateDedupBoundInputBindingPass()); if (mlir::failed(pm.run(*module_))) - return diag_handler.Combine( - errors::Internal("failed to lifting variables.")); + return diag_handler.Combine(errors::Internal("Failed to lift variables.")); return Status::OK(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc index 77da19d6853..f6d370ca604 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.cc @@ -28,14 +28,15 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { Status ParseOutputArrayInfo(absl::string_view array_names, std::vector* outputs) { - std::vector output_names = absl::StrSplit(array_names, ','); - return ParseOutputArrayInfo(output_names, outputs); + TF_RETURN_IF_ERROR(ParseNodeNames(array_names, *outputs)); + return Status::OK(); } Status ParseOutputArrayInfo(const std::vector& output_names, @@ -51,22 +52,12 @@ Status ParseInputArrayInfo(absl::string_view array_names, absl::string_view data_types, absl::string_view shapes, GraphImportConfig::InputArrays* inputs) { - std::vector node_names = absl::StrSplit(array_names, ','); - std::vector node_dtypes = absl::StrSplit(data_types, ','); - - std::vector node_shapes_str = absl::StrSplit(shapes, ':'); + std::vector node_names; + std::vector node_dtypes; std::vector> node_shapes; - for (int i = 0; i < node_shapes_str.size(); i++) { - std::vector dims; - for (auto& dim_str : absl::StrSplit(node_shapes_str[i], ',')) { - // Treats empty input shape as scalar - if (dim_str.empty()) continue; - int size; - TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size)); - dims.push_back(size); - } - node_shapes.push_back(dims); - } + TF_RETURN_IF_ERROR(ParseNodeNames(array_names, node_names)); + TF_RETURN_IF_ERROR(ParseNodeDataTypes(data_types, node_dtypes)); + TF_RETURN_IF_ERROR(ParseNodeShapes(shapes, node_shapes)); return ParseInputArrayInfo(node_names, node_dtypes, node_shapes, inputs); } @@ -75,8 +66,7 @@ Status ParseInputArrayInfo(const std::vector& node_names, const std::vector>& node_shapes, GraphImportConfig::InputArrays* inputs) { std::vector used_node_dtypes; - if (node_dtypes.empty() || - (node_dtypes.size() == 1 && node_dtypes[0].empty())) { + if (node_dtypes.empty()) { // Mark all the node dtypes Invalid, so the importer can handle them by // using the type from the graph. used_node_dtypes.resize(node_names.size(), DataType_Name(DT_INVALID)); @@ -97,14 +87,14 @@ Status ParseInputArrayInfo(const std::vector& node_names, node_names.size(), ", #data_types ", node_dtypes.size(), ")")); } - if (node_names.size() != node_shapes.size()) { + if (!node_shapes.empty() && node_names.size() != node_shapes.size()) { return errors::FailedPrecondition(absl::StrCat( - "Unmatched node array and data type numbers (#arrays ", - node_names.size(), ", #input_shapes ", node_shapes.size(), ")")); + "Unmatched node array and shape numbers (#arrays ", node_names.size(), + ", #input_shapes ", node_shapes.size(), ")")); } // StringMap doesn't support reserve else reserve input map size here. 
- for (int i = 0; i < node_names.size(); i++) { + for (int i = 0, end = node_names.size(); i < end; i++) { auto& name = node_names[i]; if (name.empty()) continue; @@ -119,11 +109,49 @@ Status ParseInputArrayInfo(const std::vector& node_names, absl::StrCat("Invalid node type '", node_dtypes[i], "'")); } - for (auto& dim : node_shapes[i]) { - info.shape.add_dim()->set_size(dim); + if (!node_shapes.empty()) { + for (auto& dim : node_shapes[i]) { + info.shape.add_dim()->set_size(dim); + } } } return Status::OK(); } +Status ParseNodeShapes(absl::string_view shapes_str, + std::vector>& shapes_vector) { + shapes_vector.clear(); + if (!shapes_str.empty()) { + std::vector node_shapes_str = absl::StrSplit(shapes_str, ':'); + for (int i = 0; i < node_shapes_str.size(); i++) { + std::vector dims; + for (const absl::string_view dim_str : + absl::StrSplit(node_shapes_str[i], ',')) { + // Treats empty input shape as scalar + if (dim_str.empty()) continue; + int size; + TF_RET_CHECK(absl::SimpleAtoi(dim_str, &size)); + dims.push_back(size); + } + shapes_vector.push_back(dims); + } + } + return Status::OK(); +} + +Status ParseNodeNames(absl::string_view names_str, + std::vector& names_vector) { + names_vector = absl::StrSplit(names_str, ',', absl::SkipEmpty()); + return Status::OK(); +} + +Status ParseNodeDataTypes(absl::string_view data_types_str, + std::vector& data_type_vector) { + data_type_vector.clear(); + if (!data_types_str.empty()) { + data_type_vector = absl::StrSplit(data_types_str, ','); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index cc38a73d106..334f935a139 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -96,6 +96,23 @@ Status ParseInputArrayInfo(const std::vector& node_names, const std::vector& node_dtypes, const std::vector>& node_shapes, GraphImportConfig::InputArrays* inputs); + +// Parses shapes from the given string into shapes_vector which is a structured +// format. +// NOTE: If shapes_str is empty, shapes_vector will also be empty. +Status ParseNodeShapes(absl::string_view shapes_str, + std::vector>& shapes_vector); + +// Parses names from the given string into the names_vector. +// NOTE: If names_str is empty, names_vector will also be empty. +Status ParseNodeNames(absl::string_view names_str, + std::vector& names_vector); + +// Parses data types from the given string into the data_type_vector. +// NOTE: If data_types_str is empty, data_type_vector will also be empty. +Status ParseNodeDataTypes(absl::string_view data_types_str, + std::vector& data_type_vector); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_ROUNDTRIP_FLAGS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index b782b2c49d9..1c7988d3a40 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/utils/transitive_fanin.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" @@ -42,11 +43,14 @@ namespace tensorflow { static StatusOr GraphdefToMlirImport( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - bool enable_shape_inference, mlir::MLIRContext* context) { + const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + mlir::MLIRContext* context) { GraphDef graphdef; TF_RETURN_IF_ERROR( tensorflow::LoadProtoFromBuffer({input.data(), input.size()}, &graphdef)); @@ -97,11 +101,14 @@ static StatusOr GraphdefToMlirImport( StatusOr GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - bool enable_shape_inference, mlir::MLIRContext* context) { + const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, @@ -113,6 +120,31 @@ StatusOr GraphdefToMlirTranslateFunction( return module_or; } +StatusOr GraphdefToMlirTranslateFunction( + llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view input_arrays, absl::string_view input_dtypes, + absl::string_view input_shapes, absl::string_view output_arrays, + absl::string_view control_output_arrays, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, + bool enable_shape_inference, mlir::MLIRContext* context) { + std::vector input_array_vector; + std::vector input_dtype_vector; + std::vector> input_shapes_vector; + std::vector output_array_vector; + std::vector control_output_array_vector; + TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector)); + TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector)); + TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector)); + TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector)); + TF_RETURN_IF_ERROR( + ParseNodeNames(control_output_arrays, control_output_array_vector)); + return GraphdefToMlirTranslateFunction( + input, debug_info_file, input_array_vector, input_dtype_vector, + input_shapes_vector, output_array_vector, control_output_array_vector, + 
prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, context); +} + StatusOr SavedModelObjectGraphToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, @@ -161,11 +193,14 @@ StatusOr SavedModelSignatureDefsToMlirImport( StatusOr GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, - absl::string_view input_arrays, absl::string_view input_dtypes, - absl::string_view input_shapes, absl::string_view output_arrays, - absl::string_view control_output_arrays, bool prune_unused_nodes, - bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, - bool enable_shape_inference, mlir::MLIRContext* context) { + const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + mlir::MLIRContext* context) { auto module_or = GraphdefToMlirImport( input, debug_info_file, input_arrays, input_dtypes, input_shapes, output_arrays, control_output_arrays, prune_unused_nodes, @@ -211,4 +246,29 @@ StatusOr GraphdefToSplattedMlirTranslateFunction( return module_or; } +StatusOr GraphdefToSplattedMlirTranslateFunction( + llvm::StringRef input, absl::string_view debug_info_file, + absl::string_view input_arrays, absl::string_view input_dtypes, + absl::string_view input_shapes, absl::string_view output_arrays, + absl::string_view control_output_arrays, bool prune_unused_nodes, + bool convert_legacy_fed_inputs, bool graph_as_function, bool upgrade_legacy, + bool enable_shape_inference, mlir::MLIRContext* context) { + std::vector input_array_vector; + std::vector input_dtype_vector; + std::vector> input_shapes_vector; + std::vector output_array_vector; + std::vector control_output_array_vector; + TF_RETURN_IF_ERROR(ParseNodeNames(input_arrays, input_array_vector)); + TF_RETURN_IF_ERROR(ParseNodeDataTypes(input_dtypes, input_dtype_vector)); + TF_RETURN_IF_ERROR(ParseNodeNames(output_arrays, output_array_vector)); + TF_RETURN_IF_ERROR(ParseNodeShapes(input_shapes, input_shapes_vector)); + TF_RETURN_IF_ERROR( + ParseNodeNames(control_output_arrays, control_output_array_vector)); + return GraphdefToSplattedMlirTranslateFunction( + input, debug_info_file, input_array_vector, input_dtype_vector, + input_shapes_vector, output_array_vector, control_output_array_vector, + prune_unused_nodes, convert_legacy_fed_inputs, graph_as_function, + upgrade_legacy, enable_shape_inference, context); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index ff5dc287488..0dc49d70192 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/base/macros.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -33,9 +34,25 @@ using stream_executor::port::StatusOr; // TODO(antiagainst): Directly manipulating files in library functions is not // a good idea. We should pass in a string/stream here. 
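
Because the string-based overloads are marked `ABSL_DEPRECATED` in the header below, new callers are expected to pass pre-parsed vectors directly. A hedged sketch of calling the structured `GraphdefToMlirTranslateFunction` overload, assuming `std::vector<std::string>` element types where the template arguments are not visible (the `graphdef_string` and `context` variables are assumed to exist in the caller, and the flag values are made up):

    std::vector<std::string> input_arrays = {"x", "y"};
    std::vector<std::string> input_dtypes = {"DT_FLOAT", "DT_FLOAT"};
    std::vector<std::vector<int>> input_shapes = {{1, 2}, {1, 2}};
    std::vector<std::string> output_arrays = {"z"};
    std::vector<std::string> control_output_arrays;
    auto module_or = tensorflow::GraphdefToMlirTranslateFunction(
        graphdef_string, /*debug_info_file=*/"", input_arrays, input_dtypes,
        input_shapes, output_arrays, control_output_arrays,
        /*prune_unused_nodes=*/true, /*convert_legacy_fed_inputs=*/false,
        /*graph_as_function=*/false, /*upgrade_legacy=*/false,
        /*enable_shape_inference=*/false, &context);
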
-// Converts a TensorFlow GraphDef stored in the file with the given -// `input_filename` into a MLIR module. Creates MLIR entities into the -// given MLIR `context`. +// Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. +// Creates MLIR entities into the given MLIR `context`. +StatusOr GraphdefToMlirTranslateFunction( + llvm::StringRef input, absl::string_view debug_info_file, + const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, + // TODO(jpienaar): Remove this. + bool enable_shape_inference, mlir::MLIRContext* context); + +ABSL_DEPRECATED( + "Please use the other overload of this function which accepts structured " + "inputs instead of strings") +// Converts a TensorFlow GraphDef contained in `input` param into a MLIR module. +// Creates MLIR entities into the given MLIR `context`. StatusOr GraphdefToMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_dtypes, @@ -47,6 +64,22 @@ StatusOr GraphdefToMlirTranslateFunction( // Similar as the above function, but replaces all constant tensors // with randomly generated splat values. +StatusOr GraphdefToSplattedMlirTranslateFunction( + llvm::StringRef input, absl::string_view debug_info_file, + const std::vector& input_arrays, + const std::vector& input_dtypes, + const std::vector>& input_shapes, + const std::vector& output_arrays, + const std::vector& control_output_arrays, + bool prune_unused_nodes, bool convert_legacy_fed_inputs, + bool graph_as_function, bool upgrade_legacy, bool enable_shape_inference, + mlir::MLIRContext* context); + +ABSL_DEPRECATED( + "Please use the other overload of this function which accepts structured " + "inputs instead of strings") +// Similar as the above function, but replaces all constant tensors +// with randomly generated splat values. StatusOr GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, absl::string_view debug_info_file, absl::string_view input_arrays, absl::string_view input_dtypes, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 5e548da55f1..eee2f0a560c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -17,11 +17,14 @@ limitations under the License. #include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -36,6 +39,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -52,6 +56,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -79,11 +84,17 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string, return Status::OK(); } +// Arguments to a computation can be either a tensor or resource. +struct TensorOrResourceShape { + TensorShape shape; + bool is_resource = false; +}; + // Converts arg_shapes to xla::Shape's and store into xla_input_shapes. Status GetXlaInputShapes( - mlir::ModuleOp module, llvm::ArrayRef arg_shapes, + mlir::ModuleOp module, llvm::ArrayRef arg_shapes, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::vector* xla_input_shapes) { xla_input_shapes->clear(); @@ -103,7 +114,7 @@ Status GetXlaInputShapes( DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype)); TF_ASSIGN_OR_RETURN(xla_shape, - shape_representation_fn(arg_shapes[i], dtype, + shape_representation_fn(arg_shapes[i].shape, dtype, /*use_fast_memory=*/false)); // Rewrite layout with sharding, if sharding is set. @@ -132,12 +143,13 @@ Status GetXlaInputShapes( } // Calculates computation output shape and build OutputDescription for each -// output based on static shapes in MLIR module +// output based on static shapes in MLIR module. If an output is a resource +// write, `resource_updates` is populated insead of `outputs` for that output. 
Status GetOutputInfo( mlir::ModuleOp module, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_output_shape, - std::vector* outputs) { + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_output_shape, std::vector* outputs, + std::vector* resource_updates) { auto shape_representation_fn_no_fast_memory = [shape_representation_fn](const TensorShape& shape, DataType dtype) { return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); @@ -148,20 +160,40 @@ Status GetOutputInfo( outputs->clear(); outputs->reserve(func_type.getNumResults()); + resource_updates->reserve(func_type.getNumResults()); std::vector shapes; shapes.reserve(func_type.getNumResults()); - for (mlir::Type type : func_type.getResults()) { + llvm::SmallDenseMap resource_arg_to_write; + for (unsigned i = 0; i < main_func.getNumArguments(); ++i) + if (auto aliasing_output = main_func.getArgAttrOfType( + i, "tf.aliasing_output")) + resource_arg_to_write.insert({aliasing_output.getInt(), i}); + + for (auto type_and_idx : llvm::enumerate(func_type.getResults())) { TF_ASSIGN_OR_RETURN( xla::Shape shape, - xla::TypeToShape(type, shape_representation_fn_no_fast_memory)); - auto tensor_type = type.dyn_cast(); + xla::TypeToShape(type_and_idx.value(), + shape_representation_fn_no_fast_memory)); + auto tensor_type = type_and_idx.value().dyn_cast(); shapes.push_back(shape); + auto it = resource_arg_to_write.find(type_and_idx.index()); + if (it != resource_arg_to_write.end()) { + // Add resource write. + resource_updates->emplace_back(); + XlaResourceUpdate& resource_update = resource_updates->back(); + resource_update.input_index = it->getSecond(); + resource_update.modified = true; + TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape)); + continue; + } + // Construct OutputDescription for result. outputs->emplace_back(); - XlaCompiler::OutputDescription& out_desc = outputs->back(); + XlaOutputDescription& out_desc = outputs->back(); TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type)); // TODO(ycao): Support constant output. out_desc.is_constant = false; @@ -181,14 +213,6 @@ Status GetOutputInfo( return Status::OK(); } -// Gets information about how computation updates Tensorflow resources. -// TODO(ycao): Implement logic to compute resource updates when we need to -// support graphs with resource updates in MLIR-based TF compiler bridge. -void GetResourceUpdatesForMlir( - std::vector* resource_updates) { - resource_updates->clear(); -} - // Creates a vector that maps from the parameters of the XLA computation to // their original argument positions. // MLIR-based TF-Compiler bridge doesn't have constant analysis yet, thus no @@ -202,7 +226,7 @@ void GetInputMappingForMlir(int num_inputs, std::vector* input_mapping) { } // Refine MLIR types based on new shape information. -Status RefineShapes(llvm::ArrayRef arg_shapes, +Status RefineShapes(llvm::ArrayRef arg_shapes, mlir::ModuleOp module) { auto producer_or = GetTfGraphProducerVersion(module); if (!producer_or.ok()) return producer_or.status(); @@ -213,15 +237,20 @@ Status RefineShapes(llvm::ArrayRef arg_shapes, { // Convert arg_shapes to a mlir friendly format. 
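
In the `GetOutputInfo` changes above, a result that aliases a resource argument is reported as a resource update instead of an ordinary output: an argument carrying `tf.aliasing_output = 2` means result #2 of `main` is the new value of that resource. The attribute lookup's template argument is not visible above; a hedged reconstruction, assuming `mlir::IntegerAttr` for the attribute type and `llvm::SmallDenseMap<int64_t, unsigned>` for the map (the exact key/value types are an assumption):

    // Map from result index to the resource argument that the result writes back.
    llvm::SmallDenseMap<int64_t, unsigned> resource_arg_to_write;
    for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
      if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
              i, "tf.aliasing_output"))
        resource_arg_to_write.insert({aliasing_output.getInt(), i});

Each aliased result then fills in an `XlaResourceUpdate` (input index, modified flag, type, shape) rather than an `XlaOutputDescription`, which is what the new `Resources` test added at the end of this change asserts.
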
size_t count = 0; - for (const TensorShape& shape : arg_shapes) { - count += shape.dims(); + for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) { + if (tensor_resource_shape.is_resource) continue; + count += tensor_resource_shape.shape.dims(); } shape_backing.resize(count); arg_shapes_copy.reserve(arg_shapes.size()); size_t offset = 0; - for (const TensorShape& shape : arg_shapes) { + for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) { + if (tensor_resource_shape.is_resource) { + arg_shapes_copy.push_back(llvm::ArrayRef()); + continue; + } size_t start = offset; - for (tensorflow::TensorShapeDim dim : shape) { + for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) { shape_backing[offset] = dim.size; ++offset; } @@ -265,7 +294,7 @@ Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::vector> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); @@ -275,35 +304,33 @@ Status ConvertMLIRToXlaComputation( tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass()); tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass()); tf2xla.addPass(mlir::createSymbolDCEPass()); + // Guarantee all functions have one use, which enables shape inference. + tf2xla.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass()); tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); // LegalizeTFControlFlow encapsulates arguments for control flow operations // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. tf2xla.addPass(mlir::mhlo::createLegalizeTFControlFlowPass()); - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(true)); + tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/true, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); for (auto& target_pass : custom_legalization_passes) { tf2xla.addNestedPass(std::move(target_pass)); } tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - - // Leverage tf2xla kernels for ops that didn't get lowered in the previous - // legalization pass. - tf2xla.addPass(mlir::mhlo::createLegalizeTfWithTf2XlaPass(device_type)); - tf2xla.addNestedPass(mlir::createCanonicalizerPass()); - // Run shape inference pass to propagate shapes through tensor_cast operations // from static to dynamic shapes. This could be generated if the shape // inference was originally missing in a TF op but the corresponding HLO op // had static shape after lowering. tf2xla.addPass(mlir::TF::CreateTFShapeInferencePass()); - // Run LegalizeTFPass again because the previous legalization passes can // expose more graph pruning and canonicalization opportunities that are // necessary for the second LegalizeTFPass(allow_partial_conversion=false) // invocation. - tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass(false)); + tf2xla.addNestedPass(mlir::mhlo::createLegalizeTFPass( + /*allow_partial_conversion=*/false, /*legalize_chlo=*/true, + /*tf2xla_fallback_device_type=*/device_type)); // In order to export to XLA, we must sink constants to control flow regions, // since XLA uses functional control flow. 
tf2xla.addNestedPass( @@ -339,10 +366,10 @@ Status ConvertMLIRToXlaComputation( } static Status CompileMlirToXlaHlo( - mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, + mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector> custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -373,14 +400,10 @@ static Status CompileMlirToXlaHlo( shape_representation_fn, &compilation_result->xla_input_shapes)); - // Compute all output descriptions. - TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn, - &compilation_result->xla_output_shape, - &compilation_result->outputs)); - - // Compute what resource variables need to be updated after XlaComputation's - // execution. - GetResourceUpdatesForMlir(&compilation_result->resource_updates); + // Compute all output descriptions and resource writes + TF_RETURN_IF_ERROR(GetOutputInfo( + module_op, shape_representation_fn, &compilation_result->xla_output_shape, + &compilation_result->outputs, &compilation_result->resource_updates)); if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_after", module_op); @@ -391,8 +414,8 @@ static Status CompileMlirToXlaHlo( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector> custom_legalization_passes) { RegisterDialects(); mlir::MLIRContext mlir_context; @@ -400,27 +423,52 @@ Status CompileSerializedMlirToXlaHlo( TF_RETURN_IF_ERROR( ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type, - use_tuple_args, shape_representation_fn, - compilation_result, + llvm::SmallVector tensor_or_resource_shapes; + tensor_or_resource_shapes.reserve(arg_shapes.size()); + for (const auto& arg_shape : arg_shapes) + tensor_or_resource_shapes.push_back({arg_shape}); + return CompileMlirToXlaHlo(mlir_module.get(), tensor_or_resource_shapes, + device_type, use_tuple_args, + shape_representation_fn, compilation_result, std::move(custom_legalization_passes)); } // Rewrites the given module with specified args. For each of the constant args, // it gets inlined in the "main' function and the corresponding argument is -// removed from the signature. +// removed from the signature. For resource args, their subtypes are populated. // Returns the original indices for the other arguments on success. 
static StatusOr> RewriteWithArgs( - mlir::ModuleOp module, llvm::ArrayRef args) { + mlir::ModuleOp module, llvm::ArrayRef args) { mlir::FuncOp main_fn = module.lookupSymbol("main"); std::vector params; + bool has_resource_args = false; auto builder = mlir::OpBuilder(main_fn.getBody()); std::vector args_to_erase; for (int idx = 0; idx < args.size(); idx++) { - const XlaCompiler::Argument& xla_arg = args[idx]; + const XlaArgument& xla_arg = args[idx]; mlir::BlockArgument mlir_arg = main_fn.getArgument(idx); - if (xla_arg.kind != XlaCompiler::Argument::kConstant) { + if (xla_arg.kind == XlaArgument::kResource) { + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type)); + auto resource_shape = absl::get(xla_arg.shape).dim_sizes(); + llvm::SmallVector resource_subtype_shape( + resource_shape.begin(), resource_shape.end()); + auto resource_subtype = + mlir::RankedTensorType::get(resource_subtype_shape, element_type); + auto resource_type = + mlir::TF::ResourceType::get({resource_subtype}, builder.getContext()); + + auto tensor_type = mlir_arg.getType().cast(); + if (tensor_type.hasRank()) { + mlir_arg.setType( + mlir::RankedTensorType::get(tensor_type.getShape(), resource_type)); + } else { + mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type)); + } + has_resource_args = true; + } + if (xla_arg.kind != XlaArgument::kConstant) { params.push_back(idx); continue; } @@ -434,22 +482,40 @@ static StatusOr> RewriteWithArgs( args_to_erase.push_back(idx); } + if (has_resource_args) { + llvm::SmallVector updated_argument_types; + updated_argument_types.reserve(main_fn.getNumArguments()); + for (mlir::BlockArgument& arg : main_fn.getArguments()) + updated_argument_types.push_back(arg.getType()); + + main_fn.setType(mlir::FunctionType::get(updated_argument_types, + main_fn.getType().getResults(), + main_fn.getContext())); + } + for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx); + return params; } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector> custom_legalization_passes) { RegisterDialects(); mlir::MLIRContext context; GraphImportConfig config; config.graph_as_function = true; + // Disable shape inference during import as some TensorFlow op fails during + // shape inference with dynamic shaped operands. This in turn causes the + // import to fail. Shape inference during import is going to be removed and + // the shape inference pass is run early in the pass pipeline, shape inference + // during import is not necessary. 
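
Several template arguments in the `RewriteWithArgs` hunk above are likewise not visible. A hedged reconstruction of the new resource-argument branch, assuming `TensorShape` for the `absl::get`, `int64_t` elements for the subtype shape vector, and `mlir::TensorType` for the cast (all inferred from the surrounding code rather than spelled out in this diff):

    if (xla_arg.kind == XlaArgument::kResource) {
      mlir::Type element_type;
      TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
      // The shape of the value held by the resource becomes the subtype shape.
      auto resource_shape = absl::get<TensorShape>(xla_arg.shape).dim_sizes();
      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
          resource_shape.begin(), resource_shape.end());
      auto resource_subtype =
          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
      auto resource_type =
          mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());

      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
      if (tensor_type.hasRank()) {
        mlir_arg.setType(
            mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
      } else {
        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
      }
      has_resource_args = true;
    }

Populating the resource subtype gives downstream passes (for example shape inference) the tensor type held by each resource argument.
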
+ config.enable_shape_inference = false; auto module_or = ConvertGraphToMlir(graph, debug_info, flib_def, config, &context); if (!module_or.ok()) return module_or.status(); @@ -457,10 +523,21 @@ Status CompileGraphToXlaHlo( mlir::ModuleOp module = module_or.ValueOrDie().get(); TF_ASSIGN_OR_RETURN(std::vector remaining_params, RewriteWithArgs(module, {args.data(), args.size()})); - llvm::SmallVector arg_shapes; - arg_shapes.reserve(args.size()); - for (unsigned idx : remaining_params) - arg_shapes.push_back(absl::get(args[idx].shape)); + llvm::SmallVector arg_shapes; + arg_shapes.reserve(remaining_params.size()); + for (unsigned idx : remaining_params) { + const auto& arg = args[idx]; + arg_shapes.push_back({absl::get(arg.shape), + /*is_resource=*/arg.kind == XlaArgument::kResource}); + } + + mlir::PassManager pm(&context); + mlir::TF::StandardPipelineOptions tf_options; + mlir::TF::CreateTFStandardPipeline(pm, tf_options); + { + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + if (failed(pm.run(module))) return diag_handler.ConsumeStatus(); + } auto status = CompileMlirToXlaHlo( module, arg_shapes, device_type, use_tuple_args, shape_representation_fn, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 24b60dcb346..5c64a65ecbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -20,7 +20,10 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -57,7 +60,7 @@ Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, std::vector> custom_legalization_passes = {}); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying @@ -65,17 +68,18 @@ Status ConvertMLIRToXlaComputation( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector> custom_legalization_passes = {}); // Same as the above but takes input as TensorFlow Graph. +// TODO(lyandy): Allow populating of targets/control outputs. 
Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef args, + const Graph& graph, llvm::ArrayRef args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector> custom_legalization_passes = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index dde2408c83a..8a07aab11e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -448,9 +451,6 @@ TEST(CompileGraphToXlaHlo, Basic) { FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); Graph graph(OpRegistry::Global()); - Tensor dummy_tensor(DT_FLOAT, TensorShape({1})); - test::FillValues(&dummy_tensor, {-1.0}); - Node* arg = test::graph::Arg(&graph, 0, DT_FLOAT); test::graph::Retval(&graph, 0, arg); @@ -483,5 +483,60 @@ ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { status_or_hlo_module.ValueOrDie()->ToString()); } +// Tests a conversion from Graph to MLIR with resource arguments. 
+TEST(CompileGraphToXlaHlo, Resources) { + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + Graph graph(OpRegistry::Global()); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto val = ops::_Arg(scope.WithOpName("arg0"), DT_FLOAT, 0); + auto var = ops::_Arg(scope.WithOpName("arg1"), DT_RESOURCE, 1); + auto assign = + ops::AssignVariableOp(scope.WithOpName("assign_variable"), var, val); + TF_ASSERT_OK(scope.ToGraph(&graph)); + + XlaCompiler::CompilationResult result; + XlaCompiler::Argument arg0; + arg0.kind = XlaCompiler::Argument::kParameter; + arg0.shape = TensorShape({2}); + XlaCompiler::Argument arg1; + arg1.kind = XlaCompiler::Argument::kResource; + arg1.shape = TensorShape({2}); + arg1.type = DT_FLOAT; + + TF_ASSERT_OK( + CompileGraphToXlaHlo(graph, /*args=*/{arg0, arg1}, "XLA_CPU_JIT", + /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), + /*shape_representation_fn=*/nullptr, &result)); + + EXPECT_EQ(result.outputs.size(), 0); + ASSERT_EQ(result.resource_updates.size(), 1); + const auto& resource_update = result.resource_updates[0]; + EXPECT_EQ(resource_update.input_index, 1); + EXPECT_EQ(resource_update.modified, true); + EXPECT_EQ(resource_update.shape, TensorShape({2})); + EXPECT_EQ(resource_update.type, DT_FLOAT); + + const xla::HloModuleConfig module_config( + result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + result.computation->proto(), module_config); + ASSERT_TRUE(status_or_hlo_module.ok()); + + constexpr char expected_hlo_module_string[] = + R"(HloModule main.4, input_output_alias={ {0}: (1, {}, may_alias) } + +ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) { + %Arg_1.2 = f32[2]{0} parameter(1) + %Arg_0.1 = f32[2]{0} parameter(0) + ROOT %tuple.3 = (f32[2]{0}) tuple(f32[2]{0} %Arg_0.1) +} + +)"; + + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index bf0b3b75ace..81892934efe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -25,6 +26,8 @@ limitations under the License. 
#include "llvm/Support/Regex.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" @@ -155,4 +158,19 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, llvm::formatv("unsupported '{0}' attribute", kDevicesAttr)); } +mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc, + llvm::StringRef device, + int64_t* device_ordinal) { + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName( + absl::string_view(device.data(), device.size()), &parsed_name)) + return mlir::emitError(loc) << "invalid device '" << device << "'"; + + if (!parsed_name.has_id) + return mlir::emitError(loc) << "device '" << device << "' has no id"; + + *device_ordinal = parsed_name.id; + return mlir::success(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h index 893e118024c..14e48bf7710 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ #include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" @@ -41,6 +42,12 @@ void AddDevicesToOp(mlir::Operation* op, const DeviceSet* device_set); mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, mlir::TF::RuntimeDevices* devices); +// Parses a device string and returns its ordinal (id). This will return an +// error if the device string is invalid or has no id. 
+mlir::LogicalResult GetDeviceOrdinalFromDeviceString(mlir::Location loc, + llvm::StringRef device, + int64_t* device_ordinal); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_DEVICE_UTIL_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index bc849e1d116..1da1f5973f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -205,5 +205,47 @@ TEST(DeviceUtilTest, GetGpuDeviceMetadata) { ASSERT_FALSE(meta_1.hasValue()); } +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceString) { + const std::string tpu0 = "/job:worker/replica:0/task:0/device:TPU:0"; + const std::string tpu1 = "/job:worker/replica:0/task:0/device:TPU:1"; + + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal0 = -1; + mlir::LogicalResult result0 = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu0, &device_ordinal0); + EXPECT_TRUE(mlir::succeeded(result0)); + EXPECT_EQ(device_ordinal0, 0); + + int64_t device_ordinal1 = -1; + mlir::LogicalResult result1 = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu1, &device_ordinal1); + EXPECT_TRUE(mlir::succeeded(result1)); + EXPECT_EQ(device_ordinal1, 1); +} + +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceStringInvalid) { + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal = -1; + mlir::LogicalResult result = GetDeviceOrdinalFromDeviceString( + unknown_loc, "bad_device", &device_ordinal); + EXPECT_TRUE(mlir::failed(result)); +} + +TEST(DeviceUtilTest, GetDeviceOrdinalFromDeviceStringNoId) { + const std::string tpu_no_id = "/job:worker/replica:0/task:0/device:TPU"; + + mlir::MLIRContext context; + auto unknown_loc = mlir::UnknownLoc::get(&context); + + int64_t device_ordinal = -1; + mlir::LogicalResult result = + GetDeviceOrdinalFromDeviceString(unknown_loc, tpu_no_id, &device_ordinal); + EXPECT_TRUE(mlir::failed(result)); +} + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h index 4feb3837357..b5f2acc581d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.h @@ -27,7 +27,7 @@ limitations under the License. namespace mlir { // TensorFlow's Status is used for error reporting back to callers. -using tensorflow::Status; +using ::tensorflow::Status; // Diagnostic handler that collects all the diagnostics reported and can produce // a Status to return to callers. This is for the case where MLIR functions are diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index 852bc72d7de..ad9ddb277d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -82,7 +82,7 @@ Status ConvertLocation(mlir::Location inst_loc, if (locations.size() <= 1) return errors::InvalidArgument("expected experimental debuf info."); // skip the first one, which is the name of the node_def. 
- for (int i = 0; i < locations.size() - 1; ++i) { + for (int i = 0, end = locations.size() - 1; i < end; ++i) { TF_RETURN_IF_ERROR(ConvertLocation(locations[i], debug_info)); } } @@ -121,6 +121,20 @@ Status ConvertAttribute(const mlir::TF::ShapeAttr& attr, AttrValue* value) { return Status::OK(); } +Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { + value->mutable_func()->set_name(attr.getValue().str()); + return Status::OK(); +} + +Status ConvertAttribute(const mlir::TF::FuncAttr& attr, AttrValue* value) { + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.GetName().cast(), value)); + TF_RETURN_IF_ERROR(ConvertAttributes(attr.GetAttrs().getValue(), + /*attrs_to_ignore=*/{}, + value->mutable_func()->mutable_attr())); + return Status::OK(); +} + Status ConvertAttribute(const mlir::StringAttr& attr, AttrValue* value) { absl::string_view attr_value(attr.getValue().data(), attr.getValue().size()); switch (mangling_util::GetMangledKind(attr_value)) { @@ -160,11 +174,6 @@ Status ConvertAttribute(const mlir::UnitAttr& attr, AttrValue* value) { return Status::OK(); } -Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { - value->mutable_func()->set_name(std::string(attr.getValue())); - return Status::OK(); -} - Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { @@ -218,25 +227,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, AttrValue* value) { return Status::OK(); } -// Updates NodeDef constructed out of an MLIR If op to map it to either -// TensorFlow StatelessIf or If op depending on the additional attribute. -void UpdateCompositeIfOp(NodeDef* node_def) { +// Updates NodeDef constructed out of an MLIR Case/IfW/While op to map it to +// either TensorFlow StatelessX or X op depending on the additional attribute. +void UpdateCompositeOp(NodeDef* node_def) { auto it = node_def->mutable_attr()->find("is_stateless"); if (it != node_def->attr().end()) { if (it->second.b()) { - *node_def->mutable_op() = "StatelessIf"; - } - node_def->mutable_attr()->erase(it); - } -} - -// Updates NodeDef constructed out of an MLIR While op to map it to either -// TensorFlow StatelessWhile or While op depending on the additional attribute. 
-void UpdateCompositeWhileOp(NodeDef* node_def) { - auto it = node_def->mutable_attr()->find("is_stateless"); - if (it != node_def->attr().end()) { - if (it->second.b()) { - *node_def->mutable_op() = "StatelessWhile"; + *node_def->mutable_op() = "Stateless" + node_def->op(); } node_def->mutable_attr()->erase(it); } @@ -343,8 +340,9 @@ StatusOr> GetOperationNodeDef( TF_RETURN_IF_ERROR(ConvertLocation( inst->getLoc(), node_def->mutable_experimental_debug_info())); - if (node_def->op() == "If") UpdateCompositeIfOp(node_def.get()); - if (node_def->op() == "While") UpdateCompositeWhileOp(node_def.get()); + if (node_def->op() == "Case") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "If") UpdateCompositeOp(node_def.get()); + if (node_def->op() == "While") UpdateCompositeOp(node_def.get()); return node_def; } @@ -372,8 +370,8 @@ Status ConvertAttributes( AttrValue value; switch (attr.getKind()) { case mlir::StandardAttributes::SymbolRef: { - auto func_attr = attr.cast(); - value.mutable_func()->set_name(std::string(func_attr.getValue())); + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.cast(), &value)); func_call_attrs[string(name)] = value; continue; } @@ -415,6 +413,12 @@ Status ConvertAttributes( TF_RETURN_IF_ERROR( ConvertAttribute(attr.cast(), &value)); break; + case static_cast(mlir::TF::AttrKind::FUNC): { + TF_RETURN_IF_ERROR( + ConvertAttribute(attr.cast(), &value)); + func_call_attrs[string(name)] = value; + continue; + } // AffineMap kind is not implemented. case mlir::StandardAttributes::AffineMap: return errors::Unimplemented("AffineMap attribute (needed for '", @@ -503,7 +507,7 @@ Status SetSizeAttribute(absl::string_view name, size_t size, // This should be extremely rare as it means we are adding the same // attribute multiple times/have some redundancy in representing this // attribute. - int64 actual_size = result.first->second.i(); + size_t actual_size = result.first->second.i(); // Just check via string output as we shouldn't get here and if we do they // should be trivially the same, else fail. if (actual_size != size) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index f884b75bce1..843d491c330 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -149,7 +149,8 @@ Status GetTPUDevices( std::next(system_devices.begin()), system_devices.end())) { auto host_tpu_devices = lookup(device_spec); // Check number of TPU devices per host all match. 
- if (num_tpus_per_host != host_tpu_devices.size()) + const int64 host_tpu_devices_size = host_tpu_devices.size(); + if (num_tpus_per_host != host_tpu_devices_size) return errors::InvalidArgument( "expected the number of TPU devices per host to be ", num_tpus_per_host, ", got ", host_tpu_devices.size()); @@ -354,7 +355,8 @@ GetGeneralTPUExecutionDeviceAssignment( const int expected_device_assignment_size = num_replicas * num_cores_per_replica * kTPUTopologyRank; - if (device_assignment_attr.size() != expected_device_assignment_size) + const int device_assignment_attr_size = device_assignment_attr.size(); + if (device_assignment_attr_size != expected_device_assignment_size) return errors::InvalidArgument( "length of '", kDeviceAssignmentAttr, "' must be 'num_replicas' * 'num_cores_per_replica' * ", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index 083a5abf840..a3f8e833ae3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -242,7 +242,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( cluster_func.getLoc(), sharding, input_value, builder, &tiled_inputs); if (mlir::failed(result)) return mlir::failure(); - if (tiled_inputs.size() != num_cores_per_replica) + const int64 tiled_inputs_size = tiled_inputs.size(); + if (tiled_inputs_size != num_cores_per_replica) cluster_func.emitError(llvm::formatv( "incorrect {0}-th tiled input sharding received. " "Product of tile sharding splits({1}) must be equal to " @@ -376,7 +377,8 @@ mlir::LogicalResult HandleTileShardedOutputs( llvm::SmallVector new_outputs; new_outputs.reserve(num_splits); - for (int i = 0; i < outputs_to_merge.size(); i = i + num_splits) { + for (int i = 0, end = outputs_to_merge.size(); i < end; + i = i + num_splits) { mlir::TF::ConcatOp concat_op; auto result = CreateConcatOp(concat_dimension, location, diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 8cfdfd01120..caac8ea1eeb 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -121,7 +121,7 @@ int main(int argc, char** argv) { mlir::MLIRContext context; auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( - input_filename, tags, exported_names, &context); + input_filename, tags, exported_names, &context, upgrade_legacy); if (!module_or.status().ok()) return 1; module_or.ConsumeValueOrDie()->print(output->os()); diff --git a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc index 9ba875cdce4..331bed09dce 100644 --- a/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc +++ b/tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc @@ -25,8 +25,7 @@ namespace tfjs { // TFJSDialect //===----------------------------------------------------------------------===// -TFJSDialect::TFJSDialect(MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void TFJSDialect::initialize() { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.cc.inc" diff --git a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc index b7e95629062..4e24007a8c6 100644 --- a/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc +++ b/tensorflow/compiler/mlir/tfjs/tf_tfjs_passes.cc @@ -45,7 +45,7 @@ void AddTFToTFJSConversionPasses(mlir::OpPassManager* pm) { // 
raise to executor dialect in order to use GraphDef converter pm->addNestedPass( mlir::CreateFunctionalToExecutorDialectConversionPass()); - pm->addNestedPass(mlir::CreateBreakUpIslandsPass()); + pm->addPass(mlir::CreateBreakUpIslandsPass()); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index b5735f823e4..5befdcdc513 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -1,7 +1,16 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -licenses(["notice"]) +package( + default_visibility = [":friends"], + licenses = ["notice"], # Apache 2.0 +) + +package_group( + name = "friends", + includes = ["//third_party/mlir:subpackages"], + packages = ["//tensorflow/compiler/mlir/..."], +) cc_library( name = "cubin_creator", @@ -50,3 +59,33 @@ tf_cc_binary( "@llvm-project//llvm:Support", ], ) + +tf_cc_binary( + name = "kernel-gen-opt", + srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"], + visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"], + deps = [ + "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration", + "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:MlirOptMain", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], +) + +exports_files(["tf_framework_c_interface.h"]) + +cc_library( + name = "tf_framework_c_interface", + srcs = ["tf_framework_c_interface.cc"], + hdrs = ["tf_framework_c_interface.h"], + deps = [ + "//tensorflow/core:framework", + "@llvm-project//mlir:mlir_runner_utils", + ], +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index 1f511e27d9e..82b0e613f90 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -278,7 +278,8 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( mlir::OwningModuleRef kernel_module = xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext); if (!llvmModule) { return InternalError("Could not translate MLIR module to NVVM"); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD new file mode 100644 index 00000000000..3a28d4815d2 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -0,0 +1,47 @@ +load("//third_party/mlir:tblgen.bzl", "gentbl") + +package( + default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], + licenses = ["notice"], # Apache 2.0 +) + +gentbl( + name = "tf_framework_ops_inc_gen", + tbl_outs = [ + ("-gen-op-decls", "tf_framework_ops.h.inc"), + ("-gen-op-defs", "tf_framework_ops.cc.inc"), + ("-gen-dialect-decls", "tf_framework_dialect.h.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_framework_ops.td", + td_srcs = [ + "tf_framework_ops.td", + "@llvm-project//mlir:OpBaseTdFiles", + 
"@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", + ], +) + +cc_library( + name = "tf_framework_ops", + srcs = [ + "tf_framework_ops.cc", + "tf_framework_ops.cc.inc", + "tf_framework_ops.h.inc", + ], + hdrs = ["tf_framework_ops.h"], + deps = [ + ":tf_framework_ops_inc_gen", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffects", + ], +) + +cc_library( + name = "tf_framework_dialect_registration", + srcs = ["dialect_registration.cc"], + deps = [ + ":tf_framework_ops", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc new file mode 100644 index 00000000000..a2e5955b570 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/dialect_registration.cc @@ -0,0 +1,21 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" + +// Static initialization for TF Framework dialect registration. +static mlir::DialectRegistration< + mlir::kernel_gen::tf_framework::TFFrameworkDialect> + tf_framework_ops; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc new file mode 100644 index 00000000000..5b7a19a3eac --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the tf_framework dialect. + +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" + +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +void TFFrameworkDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" + >(); + addTypes(); +} + +/// Parse a type registered to this dialect. 
+Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) return Type(); + + if (keyword == "op_kernel_context") { + return OpKernelContextType::get(getContext()); + } + + parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ") + << keyword; + return Type(); +} + +/// Print a type registered to this dialect. +void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { + switch (type.getKind()) { + case TFFrameworkTypes::OpKernelContextType: + os << "op_kernel_context"; + return; + default: + llvm_unreachable("unexpected TF Framework type kind"); + } +} + +template +LogicalResult Verify(OpTy op) { + return success(); +} + +//===----------------------------------------------------------------------===// +// AllocRawOp +//===----------------------------------------------------------------------===// +template <> +LogicalResult Verify(AllocRawOp op) { + // Check that the total number of operands matches the number of dynamic + // dimensions specified in the memref type. + unsigned result_dyn_dims = op.getType().getNumDynamicDims(); + unsigned dyn_sizes_count = op.dyn_sizes().size(); + if (dyn_sizes_count != result_dyn_dims) + return op.emitOpError() + << "`dyn_sizes` count " << dyn_sizes_count + << " does not match dynamic dimensions count in the result type" + << op.getType(); + return success(); +} + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h new file mode 100644 index 00000000000..8d6e433d9b9 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the TFFramework dialect. 
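Given the parseType/printType hooks above, the context type round-trips in textual IR as `!tf_framework.op_kernel_context`. An entry function that threads the TF context through as its first argument might look like the following sketch (function name and extra argument are illustrative):

    func @compute(%ctx: !tf_framework.op_kernel_context, %size: index) {
      return
    }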
+// +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ + +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +namespace TFFrameworkTypes { +enum Kind { + OpKernelContextType = Type::FIRST_TF_FRAMEWORK_TYPE, +}; +} // namespace TFFrameworkTypes + +/// OpKernelContextType corresponds to C++ class OpKernelContext defined in +/// tensorflow/core/framework/op_kernel.h +class OpKernelContextType + : public Type::TypeBase { + public: + using Base::Base; + + static OpKernelContextType get(MLIRContext *context) { + return Base::get(context, TFFrameworkTypes::Kind::OpKernelContextType); + } + + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { + return kind == TFFrameworkTypes::Kind::OpKernelContextType; + } +}; + +#define GET_OP_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.h.inc" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h.inc" + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_IR_TF_FRAMEWORK_OPS_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td new file mode 100644 index 00000000000..bc390a5aaa5 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -0,0 +1,125 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for TF Framework ops. + +#ifndef TF_FRAMEWORK_OPS +#define TF_FRAMEWORK_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def TFFramework_Dialect : Dialect { + let name = "tf_framework"; + + let summary = "Types and operations for tf_framework dialect"; + let description = [{ + This dialect contains operations and types for that correspond to + TensorFlow C++ Framework. + }]; + let cppNamespace = "kernel_gen::tf_framework"; +} + +def TFFramework_OpKernelContextType : DialectType()">, + "op_kernel_construction">, + BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::OpKernelContextType>()"> { + let typeDescription = [{ + OpKernelContextType corresponds to C++ class OpKernelContext defined in + tensorflow/core/framework/op_kernel.h + }]; +} + +// Base class for TF Framework dialect ops. 
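Based on the assembly formats defined below, the three ops round-trip in textual IR roughly as in this sketch (the memref type and SSA names are illustrative):

    %ctx = tf_framework.null_context() : !tf_framework.op_kernel_context
    %buf = tf_framework.alloc_raw(%ctx, %size) : memref<?x10xf32>
    tf_framework.dealloc_raw(%ctx, %buf) : memref<?x10xf32>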
+class TFFramework_Op traits = []> : + Op { + let verifier = "return Verify<$cppClass>(*this);"; +} + +//===----------------------------------------------------------------------===// +// AllocRawOp +//===----------------------------------------------------------------------===// +def TFFramework_AllocRawOp : TFFramework_Op<"alloc_raw", + [MemoryEffects<[MemAlloc]>]> { + let summary = "allocation of tensors that uses TF Framework"; + let description = [{ + Allocation of tensors during kernel execution in the Compute method. + + This should be used to allocate any temporary or output memref. + Corresponds to `Allocator::AllocateRaw` in + tensorflow/core/framework/allocator.h. + }]; + + let arguments = (ins TFFramework_OpKernelContextType:$ctx, + Variadic:$dyn_sizes); + let results = (outs Res]>:$result); + + let builders = [ + OpBuilder<[{ + OpBuilder &builder, OperationState &result, MemRefType memref_type, + Value ctx + }], [{ + result.addOperands(ctx); + result.types.push_back(memref_type); + }]>, + + OpBuilder<[{ + OpBuilder &builder, OperationState &result, MemRefType memref_type, + Value ctx, ValueRange dyn_sizes + }], [{ + build(builder, result, memref_type, ctx); + result.addOperands(dyn_sizes); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + }]; + let assemblyFormat = [{ + `(` $ctx (`,` $dyn_sizes^ )? `)` attr-dict `:` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// DeallocRawOp +//===----------------------------------------------------------------------===// +def TFFramework_DeallocRawOp : TFFramework_Op<"dealloc_raw", + [MemoryEffects<[MemFree]>]> { + let summary = "deallocation of tensors that uses TF Framework"; + let description = [{ + Deallocation of tensors during kernel execution in the Compute method. + + This should be used to deallocate any temporary memref that was allocated + with `tf_framework.alloc_raw`. + Corresponds to `Allocator::DeallocateRaw` in + tensorflow/core/framework/allocator.h. + }]; + + let arguments = (ins TFFramework_OpKernelContextType:$ctx, + Arg:$memref); + let assemblyFormat = "`(` $ctx `,` $memref `)` attr-dict `:` type($memref)"; +} + +//===----------------------------------------------------------------------===// +// NullContextOp +//===----------------------------------------------------------------------===// +def TFFramework_NullContextOp : TFFramework_Op<"null_context", + [NoSideEffect]> { + let summary = "Creates a fake TF context that will be lowered to nullptr"; + let description = [{Needed for testing}]; + let results = (outs TFFramework_OpKernelContextType:$result); + let assemblyFormat = "`(` `)` attr-dict `:` type($result)"; +} + +#endif // TF_FRAMEWORK_OPS diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc new file mode 100644 index 00000000000..e75db59d885 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.cc @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h" + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +using tensorflow::Allocator; + +Allocator* GetAllocator(void* op_kernel_ctx) { + auto* ctx = static_cast(op_kernel_ctx); + // TODO(pifon): Figure out how to set AllocatorAttributes correctly. + tensorflow::AllocatorAttributes attrs; + return ctx->get_allocator(attrs); +} + +} // namespace + +extern "C" void* _mlir_ciface_tf_alloc_raw(void* op_kernel_ctx, + size_t num_bytes) { + return GetAllocator(op_kernel_ctx) + ->AllocateRaw(Allocator::kAllocatorAlignment, num_bytes); +} + +extern "C" void _mlir_ciface_tf_dealloc_raw(void* op_kernel_ctx, void* ptr) { + GetAllocator(op_kernel_ctx)->DeallocateRaw(ptr); +} + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h new file mode 100644 index 00000000000..143ebc95932 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_framework_c_interface.h @@ -0,0 +1,35 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ + +#include "mlir/ExecutionEngine/RunnerUtils.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { + +extern "C" MLIR_RUNNERUTILS_EXPORT void* _mlir_ciface_tf_alloc_raw( + void* op_kernel_ctx, size_t num_bytes); + +extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_tf_dealloc_raw( + void* op_kernel_ctx, void* ptr); + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TESTS_TF_FRAMEWORK_C_INTERFACE_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc new file mode 100644 index 00000000000..c1af35617b1 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +// NOLINTNEXTLINE +static llvm::cl::opt inputFilename(llvm::cl::Positional, + llvm::cl::desc(""), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt outputFilename( + "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt splitInputFile( + "split-input-file", + llvm::cl::desc("Split the input file into pieces and process each " + "chunk independently"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt verifyDiagnostics( + "verify-diagnostics", + llvm::cl::desc("Check that emitted diagnostics match " + "expected-* lines on the corresponding line"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt verifyPasses( + "verify-each", + llvm::cl::desc("Run the verifier after each transformation pass"), + llvm::cl::init(true)); + +// NOLINTNEXTLINE +static llvm::cl::opt allowUnregisteredDialects( + "allow-unregistered-dialect", + llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::init(false)); + +// NOLINTNEXTLINE +static llvm::cl::opt showDialects( + "show-dialects", llvm::cl::desc("Print the list of registered dialects"), + llvm::cl::init(false)); + +int main(int argc, char **argv) { + mlir::registerAllDialects(); + mlir::registerAllPasses(); + + mlir::mhlo::registerAllDialects(); + mlir::kernel_gen::registerKernelGenPasses(); + + llvm::InitLLVM y(argc, argv); + + // Register any pass manager command line options. + mlir::registerAsmPrinterCLOptions(); + mlir::registerPassManagerCLOptions(); + mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); + + // Parse pass names in main to ensure static initialization completed. + llvm::cl::ParseCommandLineOptions(argc, argv, + "MLIR modular optimizer driver\n"); + + if (showDialects) { + mlir::MLIRContext context; + llvm::outs() << "Registered Dialects:\n"; + for (mlir::Dialect *dialect : context.getRegisteredDialects()) { + llvm::outs() << dialect->getNamespace() << "\n"; + } + return 0; + } + + // Set up the input file. 
+ std::string errorMessage; + auto file = mlir::openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + auto output = mlir::openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + exit(1); + } + + if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, + splitInputFile, verifyDiagnostics, verifyPasses, + allowUnregisteredDialects))) { + return 1; + } + // Keep the output file if the invocation of MlirOptMain was successful. + output->keep(); + return 0; +} diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD new file mode 100644 index 00000000000..b0f22b40f5b --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -0,0 +1,93 @@ +load("//third_party/mlir:tblgen.bzl", "gentbl") + +package( + default_visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen:friends"], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "tf_framework_legalize_to_llvm", + srcs = ["tf_framework_legalize_to_llvm.cc"], + hdrs = ["rewriters.h"], + deps = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "bufferize", + srcs = ["bufferize.cc"], + hdrs = ["rewriters.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +cc_library( + name = "embed_tf_framework", + srcs = ["embed_tf_framework.cc"], + hdrs = ["rewriters.h"], + deps = [ + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl( + name = "kernel_gen_passes_inc_gen", + tbl_outs = [("-gen-pass-decls -name KernelGen", "kernel_gen_passes.h.inc")], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + td_srcs = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = [ + "bufferize_pass.cc", + "embed_tf_framework_pass.cc", + "shape_to_descriptors_pass.cc", + "tf_framework_legalize_to_llvm_pass.cc", + ], + hdrs = ["passes.h"], + deps = [ + ":bufferize", + ":embed_tf_framework", + ":kernel_gen_passes_inc_gen", + ":tf_framework_legalize_to_llvm", + "//tensorflow/compiler/mlir/hlo", + "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/hlo:lhlo_legalize_to_llvm", + "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Shape", + "@llvm-project//mlir:ShapeToSCF", + "@llvm-project//mlir:ShapeToStandard", + "@llvm-project//mlir:ShapeTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git 
a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc new file mode 100644 index 00000000000..3d5c820e6dd --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize.cc @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for translating mixed IR to buffer form. + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { +namespace kernel_gen { +namespace transforms { + +namespace { + +class TensorFromElementsOpConverter + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + TensorFromElementsOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorFromElementsOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ShapedType result_type = op.getType().cast(); + int number_of_elements = op.elements().size(); + MemRefType memref_type = + MemRefType::get({number_of_elements}, result_type.getElementType()); + Value result = rewriter.create(loc, memref_type); + for (auto operand : llvm::enumerate(operands)) { + Value index = rewriter.create(loc, operand.index()); + rewriter.create(loc, operand.value(), result, index); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +class TensorLoadOpConversion + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + TensorLoadOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + TensorLoadOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + TensorLoadOpAdaptor adaptor(operands); + rewriter.replaceOp(op, {adaptor.memref()}); + return success(); + } +}; + +class ExtractElementOpConversion + : public BufferAssignmentOpConversionPattern { + public: + using BufferAssignmentOpConversionPattern< + ExtractElementOp>::BufferAssignmentOpConversionPattern; + + LogicalResult matchAndRewrite( + ExtractElementOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ExtractElementOpAdaptor adaptor(operands); + + if (!adaptor.aggregate().getType().isa()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(op, adaptor.aggregate(), + adaptor.indices()); + return success(); + } +}; + +} // namespace + +void 
populateStandardBufferizePattern(MLIRContext *context, + BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, + OwningRewritePatternList *patterns) { + patterns->insert(context, bufferAssignment, + converter); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc new file mode 100644 index 00000000000..ef07c801bc4 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/bufferize_pass.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for translating mixed IR to buffer form. +// Currently it supports MHLO and some operations from the Standard dialect. + +#include + +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +// TODO(herhut) : This could become a real pattern in bufferize pass. What we +// would need to do is insert a copy to model the semantics correctly. The same +// is true for the TensorLoad pattern that is already in there. Then buffer +// assignment free insertion and copy removal should clean this up for us. +// +// This patten erases `tensor_store(src_unranked_tensor, dst_unranked_memref)` +// op and replaces the result of the defining op produced `dst_unranked_memref` +// with the rewritten `src_unranked_tensor`. 
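In sketch form, the pattern implemented by the class below takes

    %dst = "some.op"() : () -> memref<*xf32>
    tensor_store %src, %dst : memref<*xf32>

and replaces all uses of %dst with %src, erasing both the tensor_store and its defining op (the op and types here are illustrative).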
+class UnrankedTensorStoreTestOnlyPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::TensorStoreOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOp(op.memref().getDefiningOp(), op.tensor()); + rewriter.replaceOp(op, {}); + return success(); + } +}; + +struct BufferizePass : public BufferizePassBase { + public: + void runOnOperation() override { + OwningRewritePatternList patterns; + auto& context = getContext(); + ConversionTarget target(context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalOp([&](TensorStoreOp op) { + return !op.tensor().getType().isa(); + }); + + BufferAssignmentTypeConverter converter; + auto typesAreLegal = [&converter](Operation* op) { + return converter.isLegal(op->getOperandTypes()) && + converter.isLegal(op->getResultTypes()); + }; + target.addDynamicallyLegalOp([&](FuncOp op) { + auto inputs = op.getType().getInputs(); + auto results = op.getType().getResults(); + return converter.isLegal(inputs) && converter.isLegal(results) && + converter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp(typesAreLegal); + target.addDynamicallyLegalOp(typesAreLegal); + + auto module = getOperation(); + WalkResult result = module.walk([&](FuncOp func) -> WalkResult { + BufferAssignmentPlacer bufferAssignment(func); + OwningRewritePatternList patterns; + mhlo::populateHLOToLHLOConversionPattern( + func.getContext(), &bufferAssignment, &converter, &patterns); + populateWithBufferAssignmentOpConversionPatterns< + ReturnOp, ReturnOp, lmhlo::CopyOp, + /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, + &converter, &patterns); + populateStandardBufferizePattern(func.getContext(), &bufferAssignment, + &converter, &patterns); + patterns.insert(func.getContext()); + + return applyPartialConversion(func, target, patterns); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr > CreateBufferizePass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc new file mode 100644 index 00000000000..aa02aefa9d2 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -0,0 +1,127 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +// Prepends argument type list of the function with an OpKernelContextType arg. +class FuncOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + FuncOp func, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Convert function arguments using the provided TypeConverter. + auto func_type = func.getType(); + TypeConverter::SignatureConversion conversion(func_type.getNumInputs()); + + conversion.addInputs(OpKernelContextType::get(rewriter.getContext())); + for (auto arg_type : llvm::enumerate(func_type.getInputs())) { + conversion.addInputs(arg_type.index(), arg_type.value()); + } + + TypeConverter type_converter; + if (failed(rewriter.convertRegionTypes(&func.getBody(), type_converter, + &conversion))) { + return failure(); + } + + // Update the signature of the function. + rewriter.updateRootInPlace(func, [&] { + func.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), + func_type.getResults())); + }); + return success(); + } +}; + +// Converts std.alloc to tf_framework.alloc_raw using OpKernelContextType arg of +// the parent function. +class AllocOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + AllocOp alloc, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto func = alloc.getParentOfType(); + if (func.getNumArguments() == 0) { + return failure(); + } + Value ctx = func.getArgument(0); + if (!ctx.getType().isa()) { + return failure(); + } + // Symbolic operands that bind to the symbols of the memref's layout map are + // not supported by AllocRawOp. + if (alloc.getNumSymbolicOperands() != 0) { + return failure(); + } + rewriter.replaceOpWithNewOp(alloc, alloc.getType(), ctx, + operands); + return success(); + } +}; + +// Converts std.dealloc to tf_framework.dealloc_raw using OpKernelContextType +// arg of the parent function. +class DeallocOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + DeallocOp dealloc, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + FuncOp func = dealloc.getParentOfType(); + if (func.getNumArguments() == 0) { + return failure(); + } + Value ctx = func.getArgument(0); + if (!ctx.getType().isa()) { + return failure(); + } + // Operand with no layout is expected. 
+ auto operand_memref_type = dealloc.memref().getType().cast(); + if (!operand_memref_type.getAffineMaps().empty()) { + return failure(); + } + DeallocOp::Adaptor transformed(operands); + rewriter.replaceOpWithNewOp(dealloc, ctx, + transformed.memref()); + return success(); + } +}; + +} // namespace + +void PopulateEmbedTFFrameworkConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert( + context); +} + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc new file mode 100644 index 00000000000..a0cfcae65d1 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -0,0 +1,77 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +static constexpr StringRef kTFEntry = "tf_entry"; + +// The pass rewrites the function marked with `tf_entry` attribute. +// * adds tf_framework::OpKernelContextType argument to the function, +// * std.alloc becomes tf_framework.alloc_raw, +// * std.dealloc becomes tf_framework.dealloc_raw. +class EmbedTFFrameworkPass + : public EmbedTFFrameworkPassBase { + public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Populate patterns. + OwningRewritePatternList patterns; + PopulateEmbedTFFrameworkConversionPatterns(m.getContext(), &patterns); + + // Set target. 
+    ConversionTarget target(getContext());
+    target.addLegalDialect<tf_framework::TFFrameworkDialect>();
+
+    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+      if (!op.getAttrOfType<UnitAttr>(kTFEntry)) {
+        return true;
+      }
+      FunctionType func_type = op.getType();
+      return func_type.getNumInputs() > 0 &&
+             func_type.getInput(0).isa<OpKernelContextType>();
+    });
+    target.addDynamicallyLegalOp<AllocOp, DeallocOp>([](Operation* op) {
+      return !op->getParentOfType<FuncOp>().getAttrOfType<UnitAttr>(kTFEntry);
+    });
+
+    if (failed(applyPartialConversion(m, target, patterns))) {
+      signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp> > createEmbedTFFrameworkPass() {
+  return std::make_unique<EmbedTFFrameworkPass>();
+}
+
+}  // namespace tf_framework
+}  // namespace kernel_gen
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
new file mode 100644
index 00000000000..e65d8402fb2
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h
@@ -0,0 +1,58 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_
+
+#include <memory>
+
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace mlir {
+namespace kernel_gen {
+namespace tf_framework {
+
+// Test pass for applying TF Framework -> LLVM patterns.
+std::unique_ptr<OperationPass<ModuleOp> >
+createTestTFFrameworkLegalizeToLLVMPass();
+
+// Pass to replace some of the Standard ops with TF Framework ops.
+// * adds tf_framework::OpKernelContextType argument to the function
+// * std.alloc becomes tf_framework.alloc_raw
+// * std.dealloc becomes tf_framework.dealloc_raw
+std::unique_ptr<OperationPass<ModuleOp> > createEmbedTFFrameworkPass();
+
+}  // namespace tf_framework
+
+namespace transforms {
+
+// Pass to transform shape computations in shape dialect to standard and scf
+// using memref descriptors.
+std::unique_ptr<OperationPass<ModuleOp> > CreateShapeToDescriptorsPass();
+
+// Pass to transform computations on values to their corresponding parts on
+// buffers.
+std::unique_ptr<OperationPass<ModuleOp> > CreateBufferizePass();
+
+}  // namespace transforms
+
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
+
+}  // namespace kernel_gen
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_PASSES_H_
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
new file mode 100644
index 00000000000..6a0e328f212
--- /dev/null
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td
@@ -0,0 +1,42 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TF_FRAMEWORK_PASSES +#define TF_FRAMEWORK_PASSES + +include "mlir/Pass/PassBase.td" + +def TestTFFrameworkLegalizeToLLVMPass + : Pass<"test-tf-framework-legalize-to-llvm", "ModuleOp"> { + let summary = "Test pass for applying TF Framework -> LLVM patterns."; + let constructor = "tf_framework::createTestTFFrameworkLegalizeToLLVMPass()"; +} + +def EmbedTFFrameworkPass : Pass<"embed-tf-framework", "ModuleOp"> { + let summary = "Pass to embed TF Framework for allocation and error reporting"; + let constructor = "tf_framework::createEmbedTFFrameworkPass()"; +} + +def ShapeToDescriptorsPass : Pass<"test-shape-to-descriptors", "ModuleOp"> { + let summary = "Pass to transform shape computations to descriptors"; + let constructor = "transforms::CreateShapeToDescriptorsPass()"; +} + +def BufferizePass : Pass<"test-bufferize", "ModuleOp"> { + let summary = "Pass to transform operations on values to buffer based ones"; + let constructor = "transforms::CreateBufferizePass()"; +} + +#endif // TF_FRAMEWORK_PASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h new file mode 100644 index 00000000000..4efc1e95bc8 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ + +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { + +class BufferAssignmentPlacer; +class LLVMTypeConverter; +class MLIRContext; +class OwningRewritePatternList; +class TypeConverter; + +namespace kernel_gen { +namespace tf_framework { + +/// Collects a set of patterns to convert from the TF Framework dialect to LLVM. +void PopulateTFFrameworkToLLVMConversionPatterns( + LLVMTypeConverter *converter, OwningRewritePatternList *patterns); + +/// Collects a set of patterns to embed TF Framework. +void PopulateEmbedTFFrameworkConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns); + +} // namespace tf_framework + +namespace transforms { + +/// Collects a set of patterns that bufferize operations from the standard +/// dialect. 
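+/// (The supplied TypeConverter determines how tensor types are mapped to
+/// their buffer (memref) counterparts.)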
+void populateStandardBufferizePattern(MLIRContext *context, + BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, + OwningRewritePatternList *patterns); +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_KERNEL_GEN_TRANSFORMS_REWRITERS_H_ diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc new file mode 100644 index 00000000000..28d3647bb63 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_to_descriptors_pass.cc @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file combines patterns for lowering shape dialect to standard ops, +// structured control flow and descriptors. + +#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h" // from @llvm-project +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Shape/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { +namespace transforms { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +struct ShapeToDescriptorsPass + : public ShapeToDescriptorsPassBase { + public: + void runOnOperation() override { + MLIRContext &ctx = getContext(); + + // Setup target legality. + ConversionTarget target(ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + // Setup conversion patterns. + OwningRewritePatternList patterns; + populateShapeRewritePatterns(&ctx, patterns); + populateShapeToStandardConversionPatterns(patterns, &ctx); + populateShapeToSCFConversionPatterns(patterns, &ctx); + + // Apply conversion. 
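+    // Partial conversion only requires the illegal shape-dialect ops to be
+    // rewritten; all other ops in the module are left untouched.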
+ auto module = getOperation(); + if (failed(applyPartialConversion(module, target, patterns))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr > CreateShapeToDescriptorsPass() { + return std::make_unique(); +} + +} // namespace transforms +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc new file mode 100644 index 00000000000..3ce111ff3ff --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -0,0 +1,201 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +using LLVM::LLVMFuncOp; +using LLVM::LLVMType; + +static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc_raw"; +static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc_raw"; + +/// Base class for patterns converting TF Framework ops to function calls. +template +class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { + public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + // Attempts to find function symbol in the module, adds it if not found. 
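+  // The function (e.g. _mlir_ciface_tf_alloc_raw) is inserted as a declaration
+  // only; its definition is expected to be provided by the TF framework's C
+  // interface at runtime.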
+ FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter, + Operation *op) const { + ModuleOp module = op->getParentOfType(); + StringRef tf_func_name = GetFuncName(); + auto tf_func = module.lookupSymbol(tf_func_name); + if (!tf_func) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + auto func_type = GetFuncType(); + tf_func = rewriter.create(rewriter.getUnknownLoc(), + tf_func_name, func_type); + } + return SymbolRefAttr::get(tf_func_name, rewriter.getContext()); + } + + protected: + virtual StringRef GetFuncName() const = 0; + virtual LLVMType GetFuncType() const = 0; +}; + +class AllocRawOpConverter : public ConvertToLLVMCallOpPattern { + public: + using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + AllocRawOp alloc_raw_op = cast(op); + AllocRawOp::Adaptor transformed(operands); + + MemRefType memref_type = alloc_raw_op.getType(); + + // Get memref descriptor sizes. + SmallVector sizes; + getMemRefDescriptorSizes(loc, memref_type, + llvm::to_vector<4>(transformed.dyn_sizes()), + rewriter, sizes); + // Get memory block size in bytes. + Value num_bytes = getCumulativeSizeInBytes( + loc, memref_type.getElementType(), sizes, rewriter); + + // Insert function call. + FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op); + Value allocated_byte_ptr = + rewriter + .create( + loc, getVoidPtrType(), tf_func_ref, + llvm::makeArrayRef({transformed.ctx(), num_bytes})) + .getResult(0); + + MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor( + loc, rewriter, memref_type, allocated_byte_ptr, sizes); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + return success(); + } + + protected: + StringRef GetFuncName() const override { return kCInterfaceAlloc; } + + LLVMType GetFuncType() const override { + LLVMType llvm_void_ptr_type = getVoidPtrType(); + return LLVM::LLVMType::getFunctionTy( + llvm_void_ptr_type, + llvm::makeArrayRef({llvm_void_ptr_type, getIndexType()}), + /*isVarArg=*/false); + } + + private: + MemRefDescriptor CreateMemRefDescriptor(Location loc, + ConversionPatternRewriter &rewriter, + MemRefType memref_type, + Value allocated_byte_ptr, + ArrayRef sizes) const { + auto memref_desc = MemRefDescriptor::undef( + rewriter, loc, typeConverter.convertType(memref_type)); + + // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr. + Value allocated_type_ptr = rewriter.create( + loc, getElementPtrType(memref_type), allocated_byte_ptr); + memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr); + memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr); + memref_desc.setConstantOffset(rewriter, loc, 0); + + if (memref_type.getRank() == 0) { + return memref_desc; + } + + // Compute strides and populate descriptor `size` and `stride` fields. 
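+    // Strides follow the default row-major layout: the innermost dimension
+    // has stride 1 and each outer stride is the product of the inner sizes.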
+ Value stride_carried = createIndexConstant(rewriter, loc, 1); + for (int pos = sizes.size() - 1; pos >= 0; --pos) { + Value size = sizes[pos]; + memref_desc.setSize(rewriter, loc, pos, size); + memref_desc.setStride(rewriter, loc, pos, stride_carried); + // Update stride + if (pos > 0) { + stride_carried = + rewriter.create(loc, stride_carried, size); + } + } + return memref_desc; + } +}; + +class DeallocRawOpConverter : public ConvertToLLVMCallOpPattern { + public: + using ConvertToLLVMCallOpPattern::ConvertToLLVMCallOpPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + DeallocRawOp::Adaptor transformed(operands); + MemRefDescriptor memref(transformed.memref()); + + Value allocated_bytes_ptr = rewriter.create( + op->getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op->getLoc())); + + // Insert function call. + FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op); + rewriter.replaceOpWithNewOp( + op, llvm::None, tf_func_ref, + llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr})); + return success(); + } + + protected: + StringRef GetFuncName() const override { return kCInterfaceDealloc; } + LLVMType GetFuncType() const override { + return LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), + /*isVarArg=*/false); + } +}; + +class NullContextOpConverter : public ConvertOpToLLVMPattern { + public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, getVoidPtrType()); + return success(); + } +}; + +} // namespace + +void PopulateTFFrameworkToLLVMConversionPatterns( + LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { + patterns->insert(*converter); + patterns->insert(*converter); +} + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc new file mode 100644 index 00000000000..42e89433dff --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm_pass.cc @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" + +namespace mlir { +namespace kernel_gen { +namespace tf_framework { +namespace { + +#define GEN_PASS_CLASSES +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +class TestTFFrameworkToLLVMPass + : public TestTFFrameworkLegalizeToLLVMPassBase { + public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + // Populate type conversions. + LLVMTypeConverter type_converter(m.getContext()); + type_converter.addConversion([&](tf_framework::OpKernelContextType type) { + return LLVM::LLVMType::getInt8PtrTy(m.getContext()); + }); + + // Populate patterns. + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(type_converter, patterns); + PopulateTFFrameworkToLLVMConversionPatterns(&type_converter, &patterns); + lmhlo::PopulateLhloToLLVMConversionPatterns(&type_converter, &patterns); + + // Set target. + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addLegalOp(); + + if (failed(applyFullConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr > +createTestTFFrameworkLegalizeToLLVMPass() { + return std::make_unique(); +} + +} // namespace tf_framework +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 838b060079c..71e18af498b 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -48,17 +48,21 @@ cc_library( srcs = [ "transforms/generated_legalize_tf.inc", "transforms/legalize_tf.cc", + "transforms/legalize_tf_communication.cc", "transforms/legalize_tf_control_flow.cc", ], hdrs = [ "transforms/passes.h", ], deps = [ + ":type_to_shape", + ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", "//tensorflow/compiler/mlir/hlo:convert_op_folder", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client:sharding_builder", @@ -92,7 +96,11 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/mlir/tensorflow:translate_utils", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_compilation_device", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_expression", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla/client:xla_builder", 
"//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", @@ -123,17 +131,21 @@ cc_library( ":hlo_utils", ":mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_parser", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Translation", ], alwayslink = 1, ) @@ -203,6 +215,7 @@ tf_cc_test( name = "type_to_shape_test", srcs = ["type_to_shape_test.cc"], deps = [ + ":hlo_utils", ":type_to_shape", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -224,11 +237,11 @@ cc_library( deps = [ ":type_to_shape", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -315,7 +328,10 @@ cc_library( hdrs = ["xla_mlir_translate.h"], deps = [ ":hlo_to_mlir_hlo", + ":mhlo_to_lhlo_with_xla", ":mlir_hlo_to_hlo", + "//tensorflow/compiler/jit:xla_cpu_jit", + "//tensorflow/compiler/jit:xla_gpu_jit", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -370,7 +386,7 @@ cc_library( ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_control_flow", "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc index ad177ce1dc5..a63fc12c285 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -521,6 +521,13 @@ StatusOr HloFunctionImporter::ImportInstruction( RandomDistributionToString(instruction->random_distribution()))); } } + case HloOpcode::kRngBitGenerator: { + auto rng_op = Cast(instruction); + auto op = func_builder->create( + loc, result_type, + func_builder->getI32IntegerAttr(rng_op->algorithm()), operands[0]); + return op.getOperation(); + } case HloOpcode::kWhile: { auto op = func_builder->create( loc, operands[0].getType(), operands[0]); @@ -708,6 +715,15 @@ StatusOr HloFunctionImporter::ImportInstruction( NoAttributeCase(kCopy, CopyOp); #undef NoAttributeCase #undef MakeAndReturn + case HloOpcode::kFusion: { + auto fusion = func_builder->create( + loc, result_type, operands, + 
builder_->getStringAttr(xla::ToString(instruction->fusion_kind()))); + TF_RETURN_IF_ERROR( + ImportAsRegion(*instruction->fused_instructions_computation(), + &fusion.fused_computation())); + return fusion.getOperation(); + } case HloOpcode::kAddDependency: // Arbitrary op code that I suspect we will not implement for quite a // while and allows testing handling of unknown ops. Selected because it diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 84c574139e9..cf78c81908d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -77,13 +77,14 @@ StatusOr> GetPermutationIfAvailable( return tensorflow::errors::Internal( "Permutations for dynamic shapes are not yet supported"); } - llvm::SmallVector permuted_sizes; - for (auto dim : llvm::reverse(shape.layout().minor_to_major())) { - permuted_sizes.push_back(shape.dimensions(dim)); + int64_t accumulated_stride = 1; + llvm::SmallVector strides(shape.rank(), 1); + for (int64 dim : LayoutUtil::MinorToMajor(shape)) { + strides[dim] = accumulated_stride; + accumulated_stride *= shape.dimensions(dim); } - return llvm::SmallVector{AffineMap::get( - permuted_sizes.size(), 0, - makeCanonicalStridedLayoutExpr(permuted_sizes, builder.getContext()))}; + return llvm::SmallVector{ + makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())}; } } // namespace diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 31512c90f09..c94110d9102 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -206,6 +206,15 @@ XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) { }); } +StatusOr MlirHloBuilder::BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create(loc_, ty, + GetValue(operand)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) { TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( @@ -224,6 +233,31 @@ StatusOr MlirHloBuilder::RevInternal( return MakeXlaOp(op); } +StatusOr MlirHloBuilder::SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValues(operands), builder_.getI64IntegerAttr(dimension), + builder_.getBoolAttr(is_stable)); + TF_RETURN_IF_ERROR(ImportComputation(comparator.proto(), &op.comparator())); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create(loc_, ty, GetValue(init)); + TF_RETURN_IF_ERROR(ImportComputation(condition.proto(), &op.cond())); + TF_RETURN_IF_ERROR(ImportComputation(body.proto(), &op.body())); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index ab1a0d2c9b3..a12eb723465 100644 --- 
a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -142,6 +142,9 @@ class MlirHloBuilder : public XlaBuilder { XlaOp Iota(const Shape& shape, int64 iota_dimension) override; + StatusOr BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) override; + StatusOr TransposeInternal( const Shape& shape, XlaOp operand, absl::Span permutation) override; @@ -149,6 +152,16 @@ class MlirHloBuilder : public XlaBuilder { StatusOr RevInternal(const Shape& shape, XlaOp operand, absl::Span dimensions) override; + StatusOr SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) override; + + StatusOr WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) override; + StatusOr GatherInternal( const Shape& shape, XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index a4c3c43cfbf..5398cd70777 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -43,7 +43,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -90,6 +89,18 @@ T* Unwrap(const std::unique_ptr& t) { return t.get(); } +static mlir::LogicalResult GetXlaOp( + mlir::Value val, const llvm::DenseMap& val_map, + xla::XlaOp* result, mlir::Operation* op) { + auto iter = val_map.find(val); + if (iter == val_map.end()) { + return op->emitOpError( + "requires all operands to be defined in the parent region for export"); + } + *result = iter->second; + return mlir::success(); +} + // Convert APInt into an int. // TODO(hpucha): This should be consolidated into a general place. static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } @@ -170,8 +181,8 @@ static std::vector> Convert_source_target_pairs( static std::vector Convert_replica_groups( mlir::DenseIntElementsAttr groups) { - int64_t num_groups = groups.getType().getDimSize(0); - int64_t group_size = groups.getType().getDimSize(1); + uint64_t num_groups = groups.getType().getDimSize(0); + uint64_t group_size = groups.getType().getDimSize(1); std::vector result; result.reserve(num_groups); @@ -435,14 +446,14 @@ static void ExtractShardingsFromFunction( llvm::SmallVectorImpl>* ret_shardings) { arg_shardings->resize(function.getNumArguments(), absl::optional()); - for (int i = 0; i < function.getNumArguments(); ++i) + for (int i = 0, end = function.getNumArguments(); i < end; ++i) if (auto sharding = function.getArgAttrOfType(i, kShardingAttr)) (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); ret_shardings->resize(function.getNumResults(), absl::optional()); - for (int i = 0; i < function.getNumResults(); ++i) + for (int i = 0, end = function.getNumResults(); i < end; ++i) if (auto sharding = function.getResultAttrOfType(i, kShardingAttr)) (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); @@ -463,7 +474,7 @@ class ConvertToHloModule { // single value. 
explicit ConvertToHloModule( mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, - tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn) + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), @@ -507,7 +518,7 @@ class ConvertToHloModule { // Lower function call to HLO call instruction LogicalResult LowerFunctionCall( - mlir::CallOp* call_op, xla::XlaBuilder* builder, + mlir::CallOp call_op, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering); private: @@ -545,7 +556,7 @@ class ConvertToHloModule { // Shape representation function to determine entry function argument and // result shapes. - tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; // Unique suffix to give to the name of the next lowered region. size_t region_id_ = 0; @@ -585,23 +596,27 @@ LogicalResult ExportXlaOp(AllReduceOp op, OpLoweringContext ctx) { return failure(); } auto replica_groups = Convert_replica_groups(op.replica_groups()); + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (!op.channel_id().hasValue()) { - value_map[op] = - xla::AllReduce(value_map[op.operand()], computation, replica_groups, - /*channel_id=*/absl::nullopt); + value_map[op] = xla::AllReduce(operand, computation, replica_groups, + /*channel_id=*/absl::nullopt); return success(); } auto channel_id = Convert_channel_handle(op.channel_id().getValue()); - value_map[op] = xla::AllReduce(value_map[op.operand()], computation, - replica_groups, channel_id); + value_map[op] = + xla::AllReduce(operand, computation, replica_groups, channel_id); return success(); } LogicalResult ExportXlaOp(BitcastConvertOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = xla::BitcastConvertType( - value_map[op.operand()], - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); return success(); } @@ -609,8 +624,11 @@ LogicalResult ExportXlaOp(BroadcastInDimOp op, OpLoweringContext ctx) { auto type = op.getType().dyn_cast(); if (!type) return failure(); auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = - BroadcastInDim(value_map[op.operand()], Convert_ArrayRef(type.getShape()), + BroadcastInDim(operand, Convert_ArrayRef(type.getShape()), Convert_broadcast_dimensions(op.broadcast_dimensions())); return success(); } @@ -640,11 +658,15 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { &false_branch))) { return failure(); } + xla::XlaOp pred, true_arg, false_arg; + if (failed(GetXlaOp(op.pred(), value_map, &pred, op))) return failure(); + if (failed(GetXlaOp(op.true_arg(), value_map, &true_arg, op))) + return failure(); + if (failed(GetXlaOp(op.false_arg(), value_map, &false_arg, op))) + return failure(); value_map[op] = - xla::Conditional(value_map[op.pred()], value_map[op.true_arg()], - true_branch, value_map[op.false_arg()], false_branch); - + xla::Conditional(pred, true_arg, true_branch, false_arg, false_branch); return success(); } @@ -657,14 +679,19 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { std::vector 
computations_p(branches.size()); for (unsigned i = 0; i < branches.size(); ++i) { - branch_operands[i] = value_map[operands[i]]; + xla::XlaOp operand; + if (failed(GetXlaOp(operands[i], value_map, &operand, op))) + return failure(); + branch_operands[i] = operand; computations_p[i] = &computations[i]; if (failed(ctx.converter->LowerRegionAsComputation(&branches[i], computations_p[i]))) return failure(); } - xla::XlaOp result = - xla::Conditional(value_map[op.index()], computations_p, branch_operands); + xla::XlaOp index; + if (failed(GetXlaOp(op.index(), value_map, &index, op))) return failure(); + + xla::XlaOp result = xla::Conditional(index, computations_p, branch_operands); if (op.getNumResults() == 1) { value_map[op.getResult(0)] = result; } else { @@ -681,9 +708,11 @@ LogicalResult ExportXlaOp(ConstOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(ConvertOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = xla::ConvertElementType( - value_map[op.operand()], - xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); + operand, xla::TypeToPrimitiveType(getElementTypeOrSelf(op.getType()))); return success(); } @@ -702,7 +731,10 @@ LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) { xla::QuantizedRange range(ConvertAPFloat(op.min_range()), ConvertAPFloat(op.max_range())); auto& value_map = *ctx.values; - auto casted = xla::ConvertElementType(value_map[op.input()], xla::U32); + xla::XlaOp input; + if (failed(GetXlaOp(op.input(), value_map, &input, op))) return failure(); + + auto casted = xla::ConvertElementType(input, xla::U32); if (op.is_16bits()) { value_map[op] = xla::Dequantize( casted, range, ConvertStringRef(op.mode()), op.transpose_output()); @@ -715,12 +747,14 @@ LogicalResult ExportXlaOp(DequantizeOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(InfeedOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; + xla::XlaOp token; + if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); + // The shape argument expected by the xla client API is the type of the first // element in the result tuple. 
auto result_type = op.getType().cast().getType(0); - value_map[op] = - xla::InfeedWithToken(value_map[op.token()], xla::TypeToShape(result_type), - std::string(op.infeed_config())); + value_map[op] = xla::InfeedWithToken(token, xla::TypeToShape(result_type), + std::string(op.infeed_config())); return success(); } @@ -745,10 +779,13 @@ LogicalResult ExportXlaOp(MapOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(OutfeedOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - value_map[op] = - xla::OutfeedWithToken(value_map[op.operand()], value_map[op.token()], - xla::TypeToShape(op.operand().getType()), - std::string(op.outfeed_config())); + xla::XlaOp operand, token; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); + + value_map[op] = xla::OutfeedWithToken( + operand, token, xla::TypeToShape(op.operand().getType()), + std::string(op.outfeed_config())); return success(); } @@ -758,29 +795,34 @@ LogicalResult ExportXlaOp(PadOp op, OpLoweringContext ctx) { auto edge_padding_low = ConvertDenseIntAttr(op.edge_padding_low()); auto edge_padding_high = ConvertDenseIntAttr(op.edge_padding_high()); auto interior_padding = ConvertDenseIntAttr(op.interior_padding()); - for (xla::int64 i = 0; i < edge_padding_low.size(); ++i) { + for (xla::int64 i = 0, end = edge_padding_low.size(); i < end; ++i) { auto* dims = padding_config.add_dimensions(); dims->set_edge_padding_low(edge_padding_low[i]); dims->set_edge_padding_high(edge_padding_high[i]); dims->set_interior_padding(interior_padding[i]); } - value_map[op] = xla::Pad(value_map[op.getOperand(0)], - value_map[op.getOperand(1)], padding_config); + xla::XlaOp operand, padding_value; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.padding_value(), value_map, &padding_value, op))) + return failure(); + + value_map[op] = xla::Pad(operand, padding_value, padding_config); return success(); } LogicalResult ExportXlaOp(RecvOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; auto result_type = op.getType().cast().getType(0); + xla::XlaOp token; + if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); + if (op.is_host_transfer()) { - value_map[op] = - xla::RecvFromHost(value_map[op.token()], xla::TypeToShape(result_type), - Convert_channel_handle(op.channel_id())); + value_map[op] = xla::RecvFromHost(token, xla::TypeToShape(result_type), + Convert_channel_handle(op.channel_id())); return success(); } - value_map[op] = - xla::RecvWithToken(value_map[op.token()], xla::TypeToShape(result_type), - Convert_channel_handle(op.channel_id())); + value_map[op] = xla::RecvWithToken(token, xla::TypeToShape(result_type), + Convert_channel_handle(op.channel_id())); return success(); } @@ -810,9 +852,13 @@ LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) { if (failed(ctx.converter->LowerRegionAsComputation(&op.body(), &body))) { return failure(); } + xla::XlaOp operand, init_value; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op))) + return failure(); + value_map[op] = xla::ReduceWindowWithGeneralPadding( - value_map[op.operand()], value_map[op.init_value()], body, - ConvertDenseIntAttr(op.window_dimensions()), + operand, init_value, body, ConvertDenseIntAttr(op.window_dimensions()), ConvertDenseIntAttr(op.window_strides()), 
ConvertDenseIntAttr(op.base_dilations()), ConvertDenseIntAttr(op.window_dilations()), @@ -822,9 +868,11 @@ LogicalResult ExportXlaOp(ReduceWindowOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(ReshapeOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - value_map[op] = xla::Reshape(value_map[op.operand()], - xla::TypeToShape(op.getType()).dimensions()); + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + value_map[op] = + xla::Reshape(operand, xla::TypeToShape(op.getType()).dimensions()); return success(); } @@ -834,17 +882,34 @@ LogicalResult ExportXlaOp(ReturnOp op, OpLoweringContext ctx) { return failure(); } +LogicalResult ExportXlaOp(RngBitGeneratorOp op, OpLoweringContext ctx) { + auto& value_map = *ctx.values; + auto result = op.getResult(); + auto xla_arg_1 = value_map[*op.getODSOperands(0).begin()]; + auto xla_result = xla::RngBitGenerator( + static_cast(op.rng_algorithm().getSExtValue()), + Unwrap(xla_arg_1), xla::TypeToShape(result.getType()).tuple_shapes(1)); + value_map[result] = xla_result; + return mlir::success(); +} + LogicalResult ExportXlaOp(RngNormalOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - value_map[op] = xla::RngNormal(value_map[op.mu()], value_map[op.sigma()], - xla::TypeToShape(op.getType())); + xla::XlaOp mu, sigma; + if (failed(GetXlaOp(op.mu(), value_map, &mu, op))) return failure(); + if (failed(GetXlaOp(op.sigma(), value_map, &sigma, op))) return failure(); + + value_map[op] = xla::RngNormal(mu, sigma, xla::TypeToShape(op.getType())); return success(); } LogicalResult ExportXlaOp(RngUniformOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - value_map[op] = xla::RngUniform(value_map[op.a()], value_map[op.b()], - xla::TypeToShape(op.getType())); + xla::XlaOp a, b; + if (failed(GetXlaOp(op.a(), value_map, &a, op))) return failure(); + if (failed(GetXlaOp(op.b(), value_map, &b, op))) return failure(); + + value_map[op] = xla::RngUniform(a, b, xla::TypeToShape(op.getType())); return success(); } @@ -857,10 +922,15 @@ LogicalResult ExportXlaOp(ScatterOp op, OpLoweringContext ctx) { } xla::ScatterDimensionNumbers dimension_numbers = Convert_scatter_dimension_numbers(op.scatter_dimension_numbers()); - value_map[op] = xla::Scatter( - value_map[op.operand()], value_map[op.scatter_indices()], - value_map[op.updates()], update_computation, dimension_numbers, - op.indices_are_sorted(), op.unique_indices()); + xla::XlaOp operand, scatter_indices, updates; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.scatter_indices(), value_map, &scatter_indices, op))) + return failure(); + if (failed(GetXlaOp(op.updates(), value_map, &updates, op))) return failure(); + + value_map[op] = xla::Scatter(operand, scatter_indices, updates, + update_computation, dimension_numbers, + op.indices_are_sorted(), op.unique_indices()); return success(); } @@ -873,26 +943,33 @@ LogicalResult ExportXlaOp(SelectAndScatterOp op, OpLoweringContext ctx) { ctx.converter->LowerRegionAsComputation(&op.scatter(), &scatter))) { return failure(); } + xla::XlaOp operand, source, init_value; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.source(), value_map, &source, op))) return failure(); + if (failed(GetXlaOp(op.init_value(), value_map, &init_value, op))) + return failure(); + value_map[op] = xla::SelectAndScatterWithGeneralPadding( - value_map[op.operand()], select, - 
ConvertDenseIntAttr(op.window_dimensions()), + operand, select, ConvertDenseIntAttr(op.window_dimensions()), ConvertDenseIntAttr(op.window_strides()), Convert_padding(op.padding()), - value_map[op.source()], value_map[op.init_value()], scatter); + source, init_value, scatter); return success(); } LogicalResult ExportXlaOp(SendOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; + xla::XlaOp operand, token; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + if (failed(GetXlaOp(op.token(), value_map, &token, op))) return failure(); + if (op.is_host_transfer()) { - value_map[op] = - xla::SendToHost(value_map[op.operand()], value_map[op.token()], - xla::TypeToShape(op.operand().getType()), - Convert_channel_handle(op.channel_id())); + value_map[op] = xla::SendToHost(operand, token, + xla::TypeToShape(op.operand().getType()), + Convert_channel_handle(op.channel_id())); return success(); } - value_map[op] = - xla::SendWithToken(value_map[op.operand()], value_map[op.token()], - Convert_channel_handle(op.channel_id())); + value_map[op] = xla::SendWithToken(operand, token, + Convert_channel_handle(op.channel_id())); return success(); } @@ -914,7 +991,9 @@ LogicalResult ExportXlaOp(SortOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; - xla::Trace(std::string(op.tag()), value_map[op.operand()]); + xla::XlaOp operand; + if (failed(GetXlaOp(op.operand(), value_map, &operand, op))) return failure(); + xla::Trace(std::string(op.tag()), operand); return success(); } @@ -933,13 +1012,40 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { return failure(); } - value_map[op] = xla::While(condition, body, value_map[op.getOperand()]); + xla::XlaOp operand; + if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) + return failure(); + value_map[op] = xla::While(condition, body, operand); return success(); } LogicalResult ExportXlaOp(FusionOp op, OpLoweringContext ctx) { - // TODO(whoever): currently not supported. 
- return failure(); + if (!op.fusion_kind()) { + op.emitOpError() << "requires fusion kind for HLO translation"; + return failure(); + } + + xla::XlaComputation fused_computation; + if (failed(ctx.converter->LowerRegionAsComputation(&op.fused_computation(), + &fused_computation))) + return failure(); + + auto& values = *ctx.values; + llvm::SmallVector operands; + for (auto operand : op.operands()) operands.push_back(values[operand]); + + xla::XlaOp fusion = xla::internal::XlaBuilderBuildFusion( + ctx.builder, operands, + absl::string_view(op.fusion_kind()->data(), op.fusion_kind()->size()), + fused_computation); + if (op.getNumResults() == 1) { + values[op.getResult(0)] = fusion; + } else { + for (auto item : llvm::enumerate(op.getResults())) { + values[item.value()] = xla::GetTupleElement(fusion, item.index()); + } + } + return success(); } } // namespace @@ -1032,7 +1138,7 @@ LogicalResult ConvertToHloModule::Lower( ElementsAttr const_attr; if (auto call_op = dyn_cast(inst)) { - return LowerFunctionCall(&call_op, builder, &value_map); + return LowerFunctionCall(call_op, builder, &value_map); } if (auto op = dyn_cast(inst)) { @@ -1046,7 +1152,10 @@ LogicalResult ConvertToHloModule::Lower( return failure(); } - value_map[op.getResult()] = value_map[operand]; + xla::XlaOp xla_operand; + if (failed(GetXlaOp(operand, value_map, &xla_operand, op))) + return failure(); + value_map[op.getResult()] = xla_operand; return success(); } @@ -1072,7 +1181,11 @@ LogicalResult ConvertToHloModule::Lower( std::vector returns(num_return_values); for (OpOperand& ret : inst->getOpOperands()) { unsigned index = ret.getOperandNumber(); - returns[index] = value_map[ret.get()]; + xla::XlaOp operand; + if (failed(GetXlaOp(ret.get(), value_map, &operand, inst))) + return failure(); + + returns[index] = operand; if (!is_entry_function || !has_ret_shardings) continue; xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); @@ -1098,7 +1211,11 @@ LogicalResult ConvertToHloModule::Lower( return_value = xla::Tuple(builder, returns); builder->ClearSharding(); } else if (num_return_values == 1) { - return_value = value_map[inst->getOperand(0)]; + xla::XlaOp operand; + if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst))) + return failure(); + + return_value = operand; } // Build the XlaComputation and check for failures. @@ -1117,14 +1234,17 @@ LogicalResult ConvertToHloModule::Lower( } LogicalResult ConvertToHloModule::LowerFunctionCall( - mlir::CallOp* call_op, xla::XlaBuilder* builder, + mlir::CallOp call_op, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering) { auto& value_map = *value_lowering; - mlir::FuncOp callee = module_.lookupSymbol(call_op->callee()); + mlir::FuncOp callee = module_.lookupSymbol(call_op.callee()); if (failed(RunOnFunction(callee))) return failure(); std::vector operands; - for (auto operand : call_op->getOperands()) { - operands.push_back(value_map[operand]); + for (auto operand : call_op.getOperands()) { + xla::XlaOp xla_operand; + if (failed(GetXlaOp(operand, value_map, &xla_operand, call_op))) + return failure(); + operands.push_back(xla_operand); } // Each call to xla::Call would insert a copy of the computation to // the HLO. 
Thus each callsite would have a unique callee in the @@ -1135,13 +1255,13 @@ LogicalResult ConvertToHloModule::LowerFunctionCall( xla::XlaOp call_result = xla::Call(builder, lowered_computation_[callee], operands); // Use GetTupleElement for multiple outputs - unsigned num_results = call_op->getNumResults(); + unsigned num_results = call_op.getNumResults(); if (num_results > 1) { for (unsigned i = 0; i != num_results; ++i) { - value_map[call_op->getResult(i)] = xla::GetTupleElement(call_result, i); + value_map[call_op.getResult(i)] = xla::GetTupleElement(call_result, i); } } else if (num_results == 1) { - value_map[call_op->getResult(0)] = call_result; + value_map[call_op.getResult(0)] = call_result; } return success(); } @@ -1271,8 +1391,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( llvm::ArrayRef> arg_shardings, llvm::ArrayRef> ret_shardings, xla::XlaComputation* result) { - // Mapping from the Value to lowered XlaOp. The code below lowers in - // program order and will fail if an operand is unseen. This can be improved. + // Mapping from the Value to lowered XlaOp. ValueLoweringMap lowering; // If using tuples as input, then there is only one input parameter that is a @@ -1498,9 +1617,19 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, } // namespace +Status ConvertRegionToComputation(mlir::Region* region, + xla::XlaComputation* func) { + mlir::ModuleOp module; + ConvertToHloModule converter(module, true, true, {}); + if (failed(converter.LowerRegionAsComputation(region, func))) + return tensorflow::errors::Internal( + "failed to convert region to computation"); + return Status::OK(); +} + Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, - const tensorflow::XlaCompiler::ShapeRepresentationFn + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); ConvertToHloModule converter(module, use_tuple_args, return_tuple, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 8bfe4c76b04..6f2b5a6db95 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -18,9 +18,10 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace mlir { @@ -33,9 +34,14 @@ namespace mlir { // single value. Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, - const tensorflow::XlaCompiler::ShapeRepresentationFn + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr); +// Converts a region to a computation. It returns a standalone module that +// contains the converted region as the entry computation. +Status ConvertRegionToComputation(mlir::Region* region, + ::xla::XlaComputation* func); + // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. 
llvm::Optional<::xla::XlaOp> CreateXlaOperator( diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc index 108544d96ff..407a7d3da38 100644 --- a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -85,18 +85,25 @@ static void BuildOperator(const Operator& op, raw_ostream& os) { // Emit an argument for an operand. if (auto* operand_cst = arg.dyn_cast()) { + std::string xla_arg = "xla_arg_" + std::to_string(index); // Handle a non-variadic operand. if (!operand_cst->isVariableLength()) { - os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands(" - << operand_number++ << ").begin()];\n"; + os << " xla::XlaOp " << xla_arg << ";\n"; + os << " if (failed(GetXlaOp(*op.getODSOperands(" << operand_number++ + << ").begin(), value_map, &" << xla_arg << ", op)))\n"; + os << " return mlir::failure();\n"; continue; } // Otherwise, this is a varidiac operand list. - os << " std::vector xla_arg_" << index << ";\n" + os << " std::vector " << xla_arg << ";\n" << " for (auto operand : op.getODSOperands(" << operand_number++ - << "))\n xla_arg_" << index - << ".push_back(value_map[operand]);\n"; + << ")) {\n"; + os << " xla::XlaOp result;\n"; + os << " if (failed(GetXlaOp(operand, value_map, &result, op)))\n"; + os << " return mlir::failure();\n"; + os << " " << xla_arg << ".push_back(result);\n"; + os << " }\n"; continue; } diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index de2f6669339..2631e2b6757 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -6,7 +6,10 @@ package(licenses = ["notice"]) glob_lit_tests( data = [":test_utilities"], driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = ["mlir"], + test_file_exts = [ + "mlir", + "hlotxt", + ], ) # Bundle together all of the test utilities that are used by tests. 
@@ -14,6 +17,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ + "//tensorflow/compiler/mlir:tf-mlir-translate", "//tensorflow/compiler/mlir:tf-opt", "//tensorflow/compiler/mlir/xla:xla-opt", "@llvm-project//llvm:FileCheck", diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt new file mode 100644 index 00000000000..3630d2d45e4 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-lhlo %s | FileCheck %s + +HloModule TestModule + +// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> + +// CHECK: func @TestComputation +ENTRY TestComputation { + x = f32[3, 2]{1,0} parameter(0) + + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () + ROOT x.copy = f32[3, 2]{0,1} copy(x) +} diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 09a85177fae..2e1b63b0db7 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -1,7 +1,7 @@ -// RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s +// RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FILECHECK_OPTS="" FileCheck --enable-var-scope %s // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -14,8 +14,8 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -29,8 +29,8 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> @@ -44,7 +44,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> 
tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -57,8 +57,8 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> @@ -72,7 +72,7 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> @@ -86,8 +86,8 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -101,7 +101,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -114,7 +114,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -127,8 +127,8 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -142,8 +142,8 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: 
memref<2x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -157,8 +157,8 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -172,7 +172,7 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -185,7 +185,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> @@ -198,7 +198,7 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> @@ -211,8 +211,8 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> @@ -226,7 +226,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // 
CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -239,9 +239,9 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 -// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.params = 2 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 2 // CHECK-SAME: %[[ARG3:.*]]: memref<16xi8> func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -255,7 +255,7 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -268,7 +268,7 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> @@ -281,8 +281,8 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> @@ -296,7 +296,7 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = {{[0-9]+}} : index, lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir index cc07624d63d..88e5e1e0a32 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir @@ -1,9 +1,9 @@ -// RUN: xla-opt -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s +// RUN: xla-opt -xla-hlo-to-lhlo-with-xla %s | FILECHECK_OPTS="" FileCheck --enable-var-scope %s // Current allocation will lead to one buffer argument for the "value" and // another one for the output, an no returned values. 
// CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 : index}, +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 0 : index}, // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true} // CHECK-SAME: ) { func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index de03921f091..69eaeeb946d 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -10,14 +10,14 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // CHECK: [[LHSSHAPE:%.*]] = shape.shape_of [[LHS]] : tensor<1x4x2xf32> // CHECK: [[RHSSHAPE:%.*]] = shape.shape_of [[RHS]] : tensor<3x2x4xf32> // CHECK: [[CM2:%.*]] = constant -2 : i32 -// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) -// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) -// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape -// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape -// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex> +// CHECK: [[LHSHEAD:%.*]], [[LHSTAIL:%.*]] = "shape.split_at"([[LHSSHAPE]], [[CM2]]) +// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) +// CHECK: [[BCASTHEAD:%.*]] = shape.broadcast [[LHSHEAD]], [[RHSHEAD]] +// CHECK: [[LHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[LHSTAIL]] +// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] // CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> -// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape -// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex> +// CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]] +// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] // CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> // CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> // CHECK: return [[RESULT]] : tensor<3x4x4xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir index 45c90d26ab4..5f3e40f923f 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir +++ 
b/tensorflow/compiler/mlir/xla/tests/legalize-tf-binary-elementwise.mlir @@ -48,8 +48,8 @@ func @add_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, tensor -> tensor + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<2xindex> // CHECK-NEXT: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-NEXT: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] : tensor @@ -201,8 +201,8 @@ func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor // NOT-CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[LHS_SHAPE]], %[[RHS_SHAPE]] // NOT-CHECK-NEXT: shape.assuming %[[WITNESS]] -> (tensor) { // NOT-CHECK-DAG: %[[LHS_SHAPE1:.+]] = shape.shape_of %arg0 - // NOT-CHECK-NEXT: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE1]], %[[RHS_SHAPE]]) - // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] + // NOT-CHECK-NEXT: %[[RESULT_SHAPE:.+]] = shape.broadcast %[[LHS_SHAPE1]], %[[RHS_SHAPE]] : tensor, tensor -> tensor + // NOT-CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<1xindex> // NOT-CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // NOT-CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // NOT-CHECK-NEXT: %[[RESULT:.+]] = "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} @@ -290,8 +290,8 @@ func @greater_dynamic(%arg0: tensor, %arg1: tensor) -> tensor, tensor -> tensor + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = tensor_cast %[[RESULT_SHAPE]] : tensor to tensor<1xindex> // CHECK-DAG: %[[LHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-NEXT: "mhlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir new file mode 100644 index 00000000000..550b2ba4da3 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir @@ -0,0 +1,1107 @@ +// RUN: tf-opt -split-input-file -verify-diagnostics -xla-legalize-tf-communication %s | FileCheck %s + +// Test legalization of `tf._XlaHostComputeMlir` expands into individual +// `mhlo.send` per operand and `mhlo.recv` per result. Channel Id's are uniquely +// assigned per mhlo communcation op, and frontend attributes (modified keys) +// and op shardings (based on `tpu_core`) are added. Sink tokens are created +// if there are more than one operand or more than one result. 
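For orientation, a single-operand/single-result `tf._XlaHostComputeMlir` lowers to one token-threaded send/recv pair along the following lines; this is only a sketch of the pattern the CHECK lines below verify, with illustrative i32/f32 element types and the frontend-attribute/sharding attributes left out.

  // Sketch only: element types assumed; frontend attributes and sharding omitted.
  %token  = "mhlo.create_token"() : () -> !mhlo.token
  %sent   = "mhlo.send"(%operand, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<i32>, !mhlo.token) -> !mhlo.token
  %recvd  = "mhlo.recv"(%sent) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true} : (!mhlo.token) -> tuple<tensor<f32>, !mhlo.token>
  %result = "mhlo.get_tuple_element"(%recvd) {index = 0 : i32} : (tuple<tensor<f32>, !mhlo.token>) -> tensor<f32>
  %token2 = "mhlo.get_tuple_element"(%recvd) {index = 1 : i32} : (tuple<tensor<f32>, !mhlo.token>) -> !mhlo.token

With more than one operand or result, the per-operand send tokens (and per-result recv tokens) are joined by `mhlo.after_all` sink tokens, as exercised by `@host_compute` below.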
+// +// The following op sharding is used: +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\00" + +// CHECK-LABEL: func @host_compute +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor) +func @host_compute(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[SEND_ARG0_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "host_compute_channel_send_dtoh_0"} + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.token + + // CHECK: [[SEND_ARG1_TOKEN:%.*]] = "mhlo.send"([[ARG1]], [[INIT_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 2 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s64", _xla_host_transfer_rendezvous = "host_compute_channel_send_dtoh_1"} + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.token + + // CHECK: [[SEND_SINK_TOKEN:%.*]] = "mhlo.after_all"([[SEND_ARG0_TOKEN]], [[SEND_ARG1_TOKEN]]) + + // CHECK: [[RECV_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[SEND_SINK_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 3 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "host_compute_channel_recv_htod_0"} + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (!mhlo.token) -> tuple, !mhlo.token> + + // CHECK: [[RECV_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tuple, !mhlo.token>) -> tensor + + // CHECK: [[RECV_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 1 + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tuple, !mhlo.token>) -> !mhlo.token + + // CHECK: [[RECV_RETVAL1_TUPLE:%.*]] = "mhlo.recv"([[SEND_SINK_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f64", _xla_host_transfer_rendezvous = "host_compute_channel_recv_htod_1"} + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (!mhlo.token) -> tuple, !mhlo.token> + + // CHECK: [[RECV_RETVAL1_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL1_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tuple, !mhlo.token>) -> tensor + + // CHECK: [[RECV_RETVAL1_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL1_TUPLE]]) + // CHECK-SAME: index = 1 + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\00" + // CHECK-SAME: (tuple, !mhlo.token>) -> !mhlo.token + + // CHECK: [[RECV_SINK_TOKEN:%.*]] = "mhlo.after_all"([[RECV_RETVAL0_TOKEN]], [[RECV_RETVAL1_TOKEN]]) + %0:2 = "tf._XlaHostComputeMlir"(%arg0, %arg1) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : (tensor, tensor) -> (tensor, tensor) + + // 
CHECK: return [[RECV_RETVAL0_VAL]], [[RECV_RETVAL1_VAL]] : tensor, tensor + return %0#0, %0#1 : tensor, tensor +} + +// ----- + +// Tests `tf._XlaHostComputeMlir` with `tpu_core` assigns the correct op +// sharding. +// +// The following op sharding is used: +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + +// CHECK-LABEL: func @host_compute_sharding +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @host_compute_sharding(%arg0: tensor) -> tensor { + // CHECK: "mhlo.send" + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\01" + // CHECK: "mhlo.recv" + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\01" + // CHECK: "mhlo.get_tuple_element" + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\01" + // CHECK: "mhlo.get_tuple_element" + // CHECK-SAME: mhlo.sharding = "\08\01\1A\01\01\22\01\01" + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 1 : i64} : (tensor) -> tensor + return %0 : tensor +} + +// ----- + +// Tests `tf._XlaHostComputeMlir` with no operands simply forwards the input +// token to its generated `mhlo.recv`. + +// CHECK-LABEL: func @host_compute_no_operands_one_result +func @host_compute_no_operands_one_result() { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK-NOT: "mhlo.send" + // CHECK-NOT: "mhlo.after_all" + // CHECK: "mhlo.recv"([[INIT_TOKEN]]) + %0 = "tf._XlaHostComputeMlir"() {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : () -> tensor + return +} + +// ----- + +// Tests `tf._XlaHostComputeMlir` with no results simply forwards its token from +// the generated `mhlo.send`. + +// CHECK-LABEL: func @host_compute_one_operand_no_results +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @host_compute_one_operand_no_results(%arg0: tensor) { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) + // CHECK-NOT: "mhlo.after_all" + "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : (tensor) -> () + + // CHECK: "mhlo.recv"([[SEND_TOKEN]]) + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + return +} + +// ----- + +// Tests `tf._XlaHostComputeMlir` with one operand and one result does not +// create any `mhlo.after_all` ops. + +// CHECK-LABEL: func @host_compute_single_operand_result +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @host_compute_single_operand_result(%arg0: tensor) { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) + // CHECK-NOT: "mhlo.after_all" + // CHECK: "mhlo.recv"([[SEND_TOKEN]]) + // CHECK-NOT: "mhlo.after_all" + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : (tensor) -> tensor + return +} + +// ----- + +// Test legalization of `tf.XlaSendToHost` expands into a `mhlo.send` op. 
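As a compact sketch of that expansion (the i32 element type is inferred from the `_xla_host_transfer_original_type = "s32"` attribute checked below):

  // Sketch only: element type assumed from the checked frontend attribute.
  %token = "mhlo.create_token"() : () -> !mhlo.token
  %done  = "mhlo.send"(%arg0, %token) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true, mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"}} : (tensor<i32>, !mhlo.token) -> !mhlo.token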
+ +// CHECK-LABEL: func @send_to_host +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @send_to_host(%arg0: tensor) { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"} + // CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.token + "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor) -> () + return +} + +// ----- + +// Test legalization of `tf.XlaRecvFromHost` expands into a `mhlo.recv` op. + +// CHECK-LABEL: func @recv_from_host +func @recv_from_host() -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[INIT_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 3 : i64} + // CHECK-SAME: is_host_transfer = true + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key"} + // CHECK-SAME: (!mhlo.token) -> tuple, !mhlo.token> + + + // CHECK: [[RECV_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK-SAME: (tuple, !mhlo.token>) -> tensor + + // CHECK: [[RECV_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 1 + // CHECK-SAME: (tuple, !mhlo.token>) -> !mhlo.token + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + + // CHECK: return [[RECV_VAL]] : tensor + return %0 : tensor +} + +// ----- + +// Test legalization of multiple TF/XLA communication ops are sequenced with +// their generated tokens. Channel Id's are also uniquely assigned. 
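Schematically, with two consecutive host computes the recv token of the first feeds the send of the second, and channel handles are assigned 1 through 4 in order:

  // Sketch only: element types assumed; frontend attributes omitted.
  %t0 = "mhlo.create_token"() : () -> !mhlo.token
  %t1 = "mhlo.send"(%arg0, %t0) {channel_id = {handle = 1 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<i32>, !mhlo.token) -> !mhlo.token
  %r0 = "mhlo.recv"(%t1) {channel_id = {handle = 2 : i64, type = 3 : i64}, is_host_transfer = true} : (!mhlo.token) -> tuple<tensor<i32>, !mhlo.token>
  %v0 = "mhlo.get_tuple_element"(%r0) {index = 0 : i32} : (tuple<tensor<i32>, !mhlo.token>) -> tensor<i32>
  %t2 = "mhlo.get_tuple_element"(%r0) {index = 1 : i32} : (tuple<tensor<i32>, !mhlo.token>) -> !mhlo.token
  %t3 = "mhlo.send"(%v0, %t2) {channel_id = {handle = 3 : i64, type = 2 : i64}, is_host_transfer = true} : (tensor<i32>, !mhlo.token) -> !mhlo.token
  %r1 = "mhlo.recv"(%t3) {channel_id = {handle = 4 : i64, type = 3 : i64}, is_host_transfer = true} : (!mhlo.token) -> tuple<tensor<i32>, !mhlo.token>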
+ +// CHECK-LABEL: func @multiple_consecutive_ops +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @multiple_consecutive_ops(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[SEND0_ARG0_TOKEN:%.*]] = "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send0_dtoh_0"} + + // CHECK: [[RECV0_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[SEND0_ARG0_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv0_htod_0"} + + // CHECK: [[RECV0_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV0_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 0 + + // CHECK: [[RECV0_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV0_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "recv0", send_key = "send0", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK: [[SEND1_ARG0_TOKEN:%.*]] = "mhlo.send"([[RECV0_RETVAL0_VAL]], [[RECV0_RETVAL0_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send1_dtoh_0"} + + // CHECK: [[RECV1_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[SEND1_ARG0_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv1_htod_0"} + + // CHECK: [[RECV1_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV1_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 0 + + // CHECK: [[RECV1_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV1_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 1 + %1 = "tf._XlaHostComputeMlir"(%0) {recv_key = "recv1", send_key = "send1", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK: return [[RECV1_RETVAL0_VAL]] : tensor + return %1 : tensor +} + +// ----- + +// Test private function with TF/XLA communication op being called by another +// function gets rewritten with an extra token argument and an extra token +// result, and the caller passes in a token. The top level function not called +// (or public) will be updated to create a token. 
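Concretely, the signature rewrite described above looks roughly like this (i32 element type assumed):

  // Sketch only: a private callee containing a communication op, before the rewrite:
  func @callee(%arg0: tensor<i32>) -> tensor<i32> attributes {sym_visibility = "private"}
  // After the rewrite: an extra !mhlo.token argument and result are threaded
  // through, and callers pass in (or create) a token.
  func @callee(%arg0: tensor<i32>, %token: !mhlo.token) -> (tensor<i32>, !mhlo.token) attributes {sym_visibility = "private"}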
+ +// CHECK: func @main([[MAIN_ARG0:%.*]]: tensor) -> tensor +func @main(%arg0: tensor) -> tensor { + // CHECK: [[MAIN_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[CALL:%.*]]:2 = call @callee([[MAIN_ARG0]], [[MAIN_TOKEN]]) + // CHECK-SAME: (tensor, !mhlo.token) -> (tensor, !mhlo.token) + %0 = call @callee(%arg0) : (tensor) -> tensor + + // CHECK: return [[CALL]]#0 + return %0 : tensor +} + +// CHECK: func @callee([[CALLEE_ARG0:%.*]]: tensor, [[CALLEE_ARG1:%.*]]: !mhlo.token) -> (tensor, !mhlo.token) +func @callee(%arg0: tensor) -> tensor attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[SEND_ARG0_TOKEN:%.*]] = "mhlo.send"([[CALLEE_ARG0]], [[CALLEE_ARG1]]) + // CHECK: [[RECV_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[SEND_ARG0_TOKEN]]) + // CHECK: [[RECV_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: [[RECV_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "recv", send_key = "send", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK: return [[RECV_RETVAL0_VAL]], [[RECV_RETVAL0_TOKEN]] + return %0 : tensor +} + +// ----- + +// Test public function with TF/XLA communication op being called by another +// function. The original public function will be modified to create a token, +// while the function is cloned and rewritten with an extra token argument and +// an extra token result. All callers to the original function are updated to +// point to the cloned function and the function the caller is in is updated to +// pass a token or create a token. + +// CHECK: func @main([[MAIN_ARG0:%.*]]: tensor) -> tensor +func @main(%arg0: tensor) -> tensor { + // CHECK: [[MAIN_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[CALL:%.*]]:2 = call [[CALLEE_CLONE:@.*]]([[MAIN_ARG0]], [[MAIN_TOKEN]]) + // CHECK-SAME: (tensor, !mhlo.token) -> (tensor, !mhlo.token) + %0 = call @callee(%arg0) : (tensor) -> tensor + + // CHECK: return [[CALL]]#0 : tensor + return %0 : tensor +} + +// CHECK: func @callee([[CALLEE_ARG0:%.*]]: tensor) -> tensor +func @callee(%arg0: tensor) -> tensor { + // CHECK: [[CALLEE_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[SEND_ARG0_TOKEN:%.*]] = "mhlo.send"([[CALLEE_ARG0]], [[CALLEE_TOKEN]]) + // CHECK: [[RECV_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[SEND_ARG0_TOKEN]]) + // CHECK: [[RECV_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: [[RECV_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_RETVAL0_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf._XlaHostComputeMlir"(%arg0) {recv_key = "recv", send_key = "send", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK: return [[RECV_RETVAL0_VAL]] + return %0 : tensor +} + +// CHECK: func [[CALLEE_CLONE]]([[CALLEE_CLONE_ARG0:%.*]]: tensor, [[CALLEE_CLONE_ARG1:%.*]]: !mhlo.token) -> (tensor, !mhlo.token) +// CHECK-NOT: "mhlo.create_token" + +// CHECK: [[CLONE_SEND_ARG0_TOKEN:%.*]] = "mhlo.send"([[CALLEE_CLONE_ARG0]], [[CALLEE_CLONE_ARG1]]) +// CHECK: [[CLONE_RECV_RETVAL0_TUPLE:%.*]] = "mhlo.recv"([[CLONE_SEND_ARG0_TOKEN]]) +// CHECK: [[CLONE_RECV_RETVAL0_VAL:%.*]] = "mhlo.get_tuple_element"([[CLONE_RECV_RETVAL0_TUPLE]]) +// CHECK-SAME: index = 0 +// CHECK: [[CLONE_RECV_RETVAL0_TOKEN:%.*]] = "mhlo.get_tuple_element"([[CLONE_RECV_RETVAL0_TUPLE]]) +// CHECK-SAME: index = 1 + +// CHECK: return [[CLONE_RECV_RETVAL0_VAL]], [[CLONE_RECV_RETVAL0_TOKEN]] + +// ----- + +// Tests generated 
tokens are passed into a function call that also has TF/XLA +// communication ops. + +// CHECK: func @main([[MAIN_ARG0:%.*]]: tensor) +func @main(%arg0: tensor) { + // CHECK: [[MAIN_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[MAIN_SEND0_TOKEN:%.*]] = "mhlo.send"([[MAIN_ARG0]], [[MAIN_TOKEN]]) + "tf.XlaSendToHost"(%arg0) {key = "send0"} : (tensor) -> () + + // CHECK: [[CALL_TOKEN:%.*]] = call @callee([[MAIN_SEND0_TOKEN]]) + // CHECK-SAME: (!mhlo.token) -> !mhlo.token + call @callee() : () -> () + + // CHECK: [[MAIN_SEND2_TOKEN:%.*]] = "mhlo.send"([[MAIN_ARG0]], [[CALL_TOKEN]]) + "tf.XlaSendToHost"(%arg0) {key = "send2"} : (tensor) -> () + return +} + +// CHECK: func @callee([[CALLEE_ARG0:%.*]]: !mhlo.token) -> !mhlo.token +func @callee() attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0> + %0 = mhlo.constant dense<0> : tensor + + // CHECK: [[CALLEE_SEND_TOKEN:%.*]] = "mhlo.send"([[ZERO]], [[CALLEE_ARG0]]) + "tf.XlaSendToHost"(%0) {key = "send1"} : (tensor) -> () + + // CHECK: return [[CALLEE_SEND_TOKEN]] + return +} + +// ----- + +// Test only the top level function generates a token. + +// CHECK: func @callee0() +func @callee0() attributes {sym_visibility = "private"} { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: call @callee1([[INIT_TOKEN]]) + call @callee1() : () -> () + return +} + +// CHECK: func @callee1([[CALLEE1_ARG0:%.*]]: !mhlo.token) -> !mhlo.token +func @callee1() attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[CALL_2:%.*]] = call @callee2([[CALLEE1_ARG0]]) + call @callee2() : () -> () + + // CHECK: return [[CALL_2]] + return +} + +// CHECK: func @callee2([[CALLEE2_ARG0:%.*]]: !mhlo.token) -> !mhlo.token +func @callee2() attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[CALLEE2_ARG0]]) + // CHECK: [[RECV_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: [[RECV_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + + // CHECK: return [[RECV_TOKEN]] + return +} + +// ----- + +// Test cloned function rewrite also checks transitive function calls to +// TF/XLA communication ops. 
+ +// CHECK: func @callee3() +func @callee3() { + // CHECK: [[CALLEE3_INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: call @callee4{{.+}}([[CALLEE3_INIT_TOKEN]]) + call @callee4() : () -> () + return +} + +// CHECK: func @callee4() +func @callee4() { + // CHECK: [[CALLEE4_INIT_TOKEN:%.*]] = "mhlo.create_token" + + // CHECK: [[CALL_5:%.*]] = call @callee5([[CALLEE4_INIT_TOKEN]]) + call @callee5() : () -> () + + // CHECK: return + return +} + +// CHECK: func @callee5([[CALLEE5_ARG0:%.*]]: !mhlo.token) -> !mhlo.token +func @callee5() attributes {sym_visibility = "private"} { + // CHECK-NOT: "mhlo.create_token" + + // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[CALLEE5_ARG0]]) + // CHECK: [[RECV_VAL:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: [[RECV_TOKEN:%.*]] = "mhlo.get_tuple_element"([[RECV_TUPLE]]) + // CHECK-SAME: index = 1 + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + + // CHECK: return [[RECV_TOKEN]] + return +} + +// CHECK: func @callee4{{.+}}([[CALLEE4_ARG0:%.*]]: !mhlo.token) -> !mhlo.token attributes {sym_visibility = "private"} +// CHECK-NOT: "mhlo.create_token" +// CHECK: [[CALL_5:%.*]] = call @callee5([[CALLEE4_ARG0]]) +// CHECK: return [[CALL_5]] + +// ----- + +// Tests `mhlo.if` with branches populated with TF/XLA communication ops. + +// CHECK-LABEL: func @if_both_branches +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_both_branches(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_VALUE]], [[TRUE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} + + // CHECK: [[TRUE_RECV_TUPLE:%.*]] = "mhlo.recv"([[TRUE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_true", send_key = "send_if_true", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: 
[[FALSE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + + // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[FALSE_REGION_ARG_VALUE]], [[FALSE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} + + // CHECK: [[FALSE_RECV_TUPLE:%.*]] = "mhlo.recv"([[FALSE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_false", send_key = "send_if_false", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with only the `true` branch populated with TF/XLA +// communication ops. + +// CHECK-LABEL: func @if_true_branch +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_true_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[TRUE_SEND_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_VALUE]], [[TRUE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_true_dtoh_0"} + + // CHECK: [[TRUE_RECV_TUPLE:%.*]] = "mhlo.recv"([[TRUE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_true_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_true", send_key = "send_if_true", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_RECV_TUPLE]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: 
"mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%arg3) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with only the `false` branch populated with TF/XLA +// communication ops. + +// CHECK-LABEL: func @if_false_branch +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_false_branch(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[TRUE_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + // CHECK: [[FALSE_TUPLE:%.*]] = "mhlo.tuple"([[ARG2]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[TRUE_TUPLE]], [[FALSE_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg2) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_GET_TUPLE_ELEMENT0]], [[TRUE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg3) : (tensor) -> () + }, { + // CHECK: ^bb0([[FALSE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[FALSE_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[FALSE_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[FALSE_REGION_ARG]]) {index = 1 + + // CHECK: [[FALSE_SEND_TOKEN:%.*]] = "mhlo.send"([[FALSE_REGION_ARG_VALUE]], [[FALSE_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_if_false_dtoh_0"} + + // CHECK: [[FALSE_RECV_TUPLE:%.*]] = "mhlo.recv"([[FALSE_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_if_false_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg3) {recv_key = "recv_if_false", send_key = "send_if_false", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[FALSE_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[FALSE_RECV_TUPLE]]) {index = 1 + // CHECK: [[FALSE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[FALSE_GET_TUPLE_ELEMENT0]], [[FALSE_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[FALSE_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + + // CHECK: (tensor, tuple, !mhlo.token>, tuple, 
!mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[IF_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` with tuple arg from a `mhlo.tuple` only used by `mhlo.if` is +// replaced. + +// CHECK-LABEL: func @if_replace_tuple_arg +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor) +func @if_replace_tuple_arg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NOT: "mhlo.tuple"([[ARG1]], [[ARG2]]) + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[ARG2]], [[INIT_TOKEN]]) + %0 = "mhlo.tuple"(%arg1, %arg2) : (tensor, tensor) -> tuple, tensor> + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[IF_ARG_TUPLE]], [[IF_ARG_TUPLE]]) + %1 = "mhlo.if"(%arg0, %0, %0) ( { + ^bb0(%arg3: tuple, tensor>): + %2 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%2) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%2) : (tensor) -> () + }, { + ^bb0(%arg3: tuple, tensor>): + %2 = "mhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple, tensor>) -> tensor + "mhlo.return"(%2) : (tensor) -> () + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tensor + return %1 : tensor +} + +// ----- + +// Tests `mhlo.if` with tuple arg not from a `mhlo.tuple` is unpacked. + +// CHECK-LABEL: func @if_unpack_tuple_arg +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tuple, tensor>) +func @if_unpack_tuple_arg(%arg0: tensor, %arg1: tuple, tensor>) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK-DAG: [[IF_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[ARG1]]) {index = 0 + // CHECK-DAG: [[IF_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[ARG1]]) {index = 1 + // CHECK: [[IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[IF_ARG_ELEMENT0]], [[IF_ARG_ELEMENT1]], [[INIT_TOKEN]]) + + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if"([[ARG0]], [[IF_ARG_TUPLE]], [[IF_ARG_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%1) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` tuple result is extended with a `mhlo.token`. 
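In other words, an `mhlo.if` that already returns a tuple simply gains a trailing `!mhlo.token` element, and the original tuple is reassembled from the leading elements afterwards (f32 element types assumed):

  // Sketch only: element types assumed.
  // result type before:  tuple<tensor<f32>, tensor<f32>>
  // result type after:   tuple<tensor<f32>, tensor<f32>, !mhlo.token>
  %e0  = "mhlo.get_tuple_element"(%if_result) {index = 0 : i32} : (tuple<tensor<f32>, tensor<f32>, !mhlo.token>) -> tensor<f32>
  %e1  = "mhlo.get_tuple_element"(%if_result) {index = 1 : i32} : (tuple<tensor<f32>, tensor<f32>, !mhlo.token>) -> tensor<f32>
  %out = "mhlo.tuple"(%e0, %e1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>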
+ +// CHECK-LABEL: func @if_extend_tuple_result +func @if_extend_tuple_result(%arg0: tensor, %arg1: tuple, tensor>) -> tuple, tensor> { + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tuple, tensor>): + %1 = "mhlo.get_tuple_element"(%arg2) {index = 0 : i32} : (tuple, tensor>) -> tensor + "tf.XlaSendToHost"(%1) {key = "send_key"} : (tensor) -> () + "mhlo.return"(%arg2) : (tuple, tensor>) -> () + }, { + ^bb0(%arg2: tuple, tensor>): + "mhlo.return"(%arg2) : (tuple, tensor>) -> () + // CHECK: (tensor, tuple, tensor, !mhlo.token>, tuple, tensor, !mhlo.token>) -> tuple, tensor, !mhlo.token> + }) : (tensor, tuple, tensor>, tuple, tensor>) -> tuple, tensor> + + // CHECK-DAG: [[IF_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 0 + // CHECK-DAG: [[IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 1 + // CHECK: [[IF_SUBTUPLE_RESULT:%.*]] = "mhlo.tuple"([[IF_TUPLE_ELEMENT0]], [[IF_TUPLE_ELEMENT1]]) + // CHECK: return [[IF_SUBTUPLE_RESULT]] + return %0 : tuple, tensor> +} + +// ----- + +// Tests nested `mhlo.if` containing TF/XLA communication ops. + +// CHECK-LABEL: func @if_nested +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor) +func @if_nested(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[OUTER_IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG1]], [[INIT_TOKEN]]) + + // CHECK: "mhlo.if"([[ARG0]], [[OUTER_IF_ARG_TUPLE]], [[OUTER_IF_ARG_TUPLE]]) + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK-NEXT: ^bb0([[OUTER_IF_TRUE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[OUTER_IF_TRUE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_TRUE_ARG]]) {index = 0 + // CHECK-DAG: [[OUTER_IF_TRUE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_TRUE_ARG]]) {index = 1 + // CHECK: [[INNER_IF_ARG_TUPLE:%.*]] = "mhlo.tuple"([[OUTER_IF_TRUE_ARG_ELEMENT0]], [[OUTER_IF_TRUE_ARG_ELEMENT1]]) + + %1 = mhlo.constant dense : tensor + + // CHECK: [[INNER_IF_TUPLE:%.*]] = "mhlo.if"({{%.*}}, [[INNER_IF_ARG_TUPLE]], [[INNER_IF_ARG_TUPLE]]) + %2 = "mhlo.if"(%1, %arg2, %arg2) ( { + // CHECK-NEXT: ^bb0([[INNER_IF_TRUE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[INNER_IF_TRUE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TRUE_ARG]]) {index = 0 + // CHECK-DAG: [[INNER_IF_TRUE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TRUE_ARG]]) {index = 1 + + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send"([[INNER_IF_TRUE_ARG_ELEMENT0]], [[INNER_IF_TRUE_ARG_ELEMENT1]]) + "tf.XlaSendToHost"(%arg3) {key = "send_key"} : (tensor) -> () + + // CHECK: [[INNER_IF_TRUE_RESULT:%.*]] = "mhlo.tuple"([[INNER_IF_TRUE_ARG_ELEMENT0]], [[SEND_TOKEN]]) + // CHECK: "mhlo.return"([[INNER_IF_TRUE_RESULT]]) + "mhlo.return"(%arg3) : (tensor) -> () + + // CHECK-NEXT: }, { + }, { + + // CHECK-NEXT: ^bb0([[INNER_IF_FALSE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg3: tensor): + // CHECK-DAG: [[INNER_IF_FALSE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_FALSE_ARG]]) {index = 0 + // CHECK-DAG: [[INNER_IF_FALSE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_FALSE_ARG]]) {index = 1 + // CHECK: [[INNER_IF_FALSE_RESULT:%.*]] = "mhlo.tuple"([[INNER_IF_FALSE_ARG_ELEMENT0]], [[INNER_IF_FALSE_ARG_ELEMENT1]]) + // CHECK: "mhlo.return"([[INNER_IF_FALSE_RESULT]]) + "mhlo.return"(%arg3) : (tensor) -> () + // CHECK-NEXT: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, 
tensor, tensor) -> tensor + + // CHECK-DAG: [[INNER_IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[INNER_IF_TUPLE]]) {index = 1 + // CHECK: [[OUTER_IF_TRUE_RESULT:%.*]] = "mhlo.tuple"([[OUTER_IF_TRUE_ARG_ELEMENT0]], [[INNER_IF_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[OUTER_IF_TRUE_RESULT]]) + "mhlo.return"(%arg2) : (tensor) -> () + + // CHECK-NEXT: }, { + }, { + + // CHECK-NEXT: ^bb0([[OUTER_IF_FALSE_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[OUTER_IF_FALSE_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_FALSE_ARG]]) {index = 0 + // CHECK-DAG: [[OUTER_IF_FALSE_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[OUTER_IF_FALSE_ARG]]) {index = 1 + // CHECK: [[OUTER_IF_FALSE_RESULT:%.*]] = "mhlo.tuple"([[OUTER_IF_FALSE_ARG_ELEMENT0]], [[OUTER_IF_FALSE_ARG_ELEMENT1]]) + // CHECK: "mhlo.return"([[OUTER_IF_FALSE_RESULT]]) + "mhlo.return"(%arg2) : (tensor) -> () + // CHECK-NEXT: (tensor, tuple, !mhlo.token>, tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// Tests `mhlo.if` containing a function call to TF/XLA communication ops. + +// CHECK-LABEL: func @if_function_call +func @if_function_call(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK-DAG: [[TRUE_REGION_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[TRUE_REGION_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + // CHECK: [[CALL_TOKEN:%.*]] = call @callee([[TRUE_REGION_ARG_ELEMENT0]], [[TRUE_REGION_ARG_ELEMENT1]]) + call @callee(%arg2) : (tensor) -> () + + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_REGION_ARG_ELEMENT0]], [[CALL_TOKEN]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @callee +// CHECK-SAME: ([[CALLEE_ARG0:%.*]]: tensor, [[CALLEE_ARG1:%.*]]: !mhlo.token) -> !mhlo.token +func @callee(%arg0: tensor) attributes {sym_visibility = "private"} { + // CHECK: [[SEND_TOKEN:%.*]] = "mhlo.send" + "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor) -> () + + // CHECK: return [[SEND_TOKEN]] + return +} + +// ----- + +// Tests `mhlo.if` containing multiple TF/XLA communication ops. 
+ +// CHECK-LABEL: func @if_region_multiple_ops +func @if_region_multiple_ops(%arg0: tensor, %arg1: tensor) { + // CHECK: "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + // CHECK: ^bb0([[TRUE_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg2: tensor): + // CHECK: [[TRUE_REGION_ARG_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 0 + // CHECK: [[TRUE_REGION_ARG_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[TRUE_REGION_ARG]]) {index = 1 + + // CHECK: [[SEND0_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_ELEMENT0]], [[TRUE_REGION_ARG_ELEMENT1]]) + "tf.XlaSendToHost"(%arg2) {key = "send_key0"} : (tensor) -> () + + // CHECK: [[SEND1_TOKEN:%.*]] = "mhlo.send"([[TRUE_REGION_ARG_ELEMENT0]], [[SEND0_TOKEN]]) + "tf.XlaSendToHost"(%arg2) {key = "send_key1"} : (tensor) -> () + + // CHECK: [[TRUE_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[TRUE_REGION_ARG_ELEMENT0]], [[SEND1_TOKEN]]) + // CHECK: "mhlo.return"([[TRUE_RETURN_TUPLE]]) + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +// Tests `mhlo.if` containing TF/XLA communication ops followed by other TF/XLA +// communication ops. + +func @if_followed_by_communication_op(%arg0: tensor, %arg1: tensor) { + // CHECK: [[IF_TUPLE:%.*]] = "mhlo.if" + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tensor): + "tf.XlaSendToHost"(%arg2) {key = "send_key0"} : (tensor) -> () + "mhlo.return"(%arg2) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + + // CHECK: [[IF_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[IF_TUPLE]]) {index = 1 + + // CHECK: "mhlo.send"({{.*}}, [[IF_TUPLE_ELEMENT1]]) + "tf.XlaSendToHost"(%arg1) {key = "send_key1"} : (tensor) -> () + return +} + +// ----- + +// Tests `mhlo.while` with cond and body populated with TF/XLA communication +// ops. 
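Structurally, the loop-carried value is packed with the token so that both the cond and body regions operate on (and the body returns) the same tuple type, and the original result is unpacked afterwards (f32 element type assumed; region bodies elided):

  // Sketch only: element type assumed, cond/body regions elided.
  %token  = "mhlo.create_token"() : () -> !mhlo.token
  %packed = "mhlo.tuple"(%arg0, %token) : (tensor<f32>, !mhlo.token) -> tuple<tensor<f32>, !mhlo.token>
  %loop   = "mhlo.while"(%packed) ( {
    // cond: unpack value/token with mhlo.get_tuple_element, return a tensor<i1>
  }, {
    // body: unpack, thread the token through send/recv, repack, return the tuple
  }) : (tuple<tensor<f32>, !mhlo.token>) -> tuple<tensor<f32>, !mhlo.token>
  %result = "mhlo.get_tuple_element"(%loop) {index = 0 : i32} : (tuple<tensor<f32>, !mhlo.token>) -> tensor<f32>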
+ +// CHECK-LABEL: func @while_cond_body +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_cond_body(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[COND_REGION_ARG_VALUE]], [[COND_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} + + // CHECK: [[COND_RECV_TUPLE:%.*]] = "mhlo.recv"([[COND_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + + // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[BODY_REGION_ARG_VALUE]], [[BODY_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 3 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} + + // CHECK: [[BODY_RECV_TUPLE:%.*]] = "mhlo.recv"([[BODY_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 4 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: 
return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` with only the `cond` region populated with TF/XLA +// communication ops. + +// CHECK-LABEL: func @while_cond +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_cond(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_SEND_TOKEN:%.*]] = "mhlo.send"([[COND_REGION_ARG_VALUE]], [[COND_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_cond_dtoh_0"} + + // CHECK: [[COND_RECV_TUPLE:%.*]] = "mhlo.recv"([[COND_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_cond_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_cond", send_key = "send_while_cond", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_RECV_TUPLE]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%arg1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` with only the `body` region populated with TF/XLA +// communication ops. 
+ +// CHECK-LABEL: func @while_body +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @while_body(%arg0: tensor) -> tensor { + // CHECK: [[INIT_TOKEN:%.*]] = "mhlo.create_token" + // CHECK: [[ARG_TUPLE:%.*]] = "mhlo.tuple"([[ARG0]], [[INIT_TOKEN]]) + + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while"([[ARG_TUPLE]]) + %0 = "mhlo.while"(%arg0) ( { + // CHECK: ^bb0([[COND_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[COND_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[COND_REGION_ARG]]) {index = 1 + + // CHECK: [[COND_COMPARE:%.*]] = "mhlo.compare"([[COND_GET_TUPLE_ELEMENT0]], [[COND_GET_TUPLE_ELEMENT0]]) + %2 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + + // CHECK: "mhlo.return"([[COND_COMPARE]]) + "mhlo.return"(%2) : (tensor) -> () + }, { + // CHECK: ^bb0([[BODY_REGION_ARG:%.*]]: tuple, !mhlo.token>): + ^bb0(%arg1: tensor): + // CHECK-DAG: [[BODY_REGION_ARG_VALUE:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 0 + // CHECK-DAG: [[BODY_REGION_ARG_TOKEN:%.*]] = "mhlo.get_tuple_element"([[BODY_REGION_ARG]]) {index = 1 + + // CHECK: [[BODY_SEND_TOKEN:%.*]] = "mhlo.send"([[BODY_REGION_ARG_VALUE]], [[BODY_REGION_ARG_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "send_while_body_dtoh_0"} + + // CHECK: [[BODY_RECV_TUPLE:%.*]] = "mhlo.recv"([[BODY_SEND_TOKEN]]) + // CHECK-SAME: channel_id = {handle = 2 : i64, type = 3 : i64} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "f32", _xla_host_transfer_rendezvous = "recv_while_body_htod_0"} + %1 = "tf._XlaHostComputeMlir"(%arg1) {recv_key = "recv_while_body", send_key = "send_while_body", tpu_core = 0 : i64} : (tensor) -> tensor + + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 0 + // CHECK-DAG: [[BODY_GET_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[BODY_RECV_TUPLE]]) {index = 1 + // CHECK: [[BODY_RETURN_TUPLE:%.*]] = "mhlo.tuple"([[BODY_GET_TUPLE_ELEMENT0]], [[BODY_GET_TUPLE_ELEMENT1]]) + // CHECK: "mhlo.return"([[BODY_RETURN_TUPLE]]) + "mhlo.return"(%1) : (tensor) -> () + // CHECK: (tuple, !mhlo.token>) -> tuple, !mhlo.token> + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT0:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) + // CHECK-SAME: index = 0 + // CHECK: return [[WHILE_TUPLE_ELEMENT0]] + return %0 : tensor +} + +// ----- + +// Tests `mhlo.while` containing TF/XLA communication ops followed by other +// TF/XLA communication ops. + +func @while_followed_by_communication_op(%arg0: tensor) { + // CHECK: [[WHILE_TUPLE:%.*]] = "mhlo.while" + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + "tf.XlaSendToHost"(%arg1) {key = "send_key0"} : (tensor) -> () + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + // CHECK: [[WHILE_TUPLE_ELEMENT1:%.*]] = "mhlo.get_tuple_element"([[WHILE_TUPLE]]) {index = 1 + + // CHECK: "mhlo.send"({{.*}}, [[WHILE_TUPLE_ELEMENT1]]) + "tf.XlaSendToHost"(%arg0) {key = "send_key1"} : (tensor) -> () + return +} + +// ----- + +// Tests unsupported parent of TF/XLA communication op. 
+ +func @unsupported_ancestor(%arg0: tensor, %arg1: tensor) { + %0 = "mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + // expected-error@+1 {{expects ancestor(s) to be of ['mhlo.if', 'func']}} + "tf._XlaHostComputeMlir"() {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : () -> () + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor + return +} + +// ----- + +// Tests transitive unsupported parent of TF/XLA communication op. + +func @unsupported_ancestor(%arg0: tensor, %arg1: tensor) { + %0 = "mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = mhlo.add %arg2, %arg3 : tensor + // expected-error@+1 {{expects ancestor(s) to be of ['mhlo.if', 'func']}} + call @callee() : () -> () + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor, tensor) -> tensor + return +} + +func @callee() attributes {sym_visibility = "private"} { + "tf._XlaHostComputeMlir"() {recv_key = "host_compute_channel_recv", send_key = "host_compute_channel_send", tpu_core = 0 : i64} : () -> () + return +} + +// ----- + +// Tests unsupported `mhlo.if` with region of more than one block and contains a +// TF/XLA communication op. + +func @if_multiple_blocks(%arg0: tensor, %arg1: tensor) { + %0 = "mhlo.if"(%arg0, %arg1, %arg1) ( { + ^bb0(%arg2: tensor): + br ^bb1(%arg2 : tensor) + ^bb1(%arg3: tensor): + // expected-error@+1 {{expects single block region ancestor(s)}} + "tf.XlaSendToHost"(%arg3) {key = "send_key0"} : (tensor) -> () + "mhlo.return"(%arg3) : (tensor) -> () + }, { + ^bb0(%arg2: tensor): + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return +} + +// ----- + +// Tests function with more than one block that is to be rewritten emits an +// error instead. + +// expected-error@+1 {{'func' ops with more than one block are not supported}} +func @multi_block_func() { + br ^bb1 +^bb1: + %0 = "tf.XlaRecvFromHost"() {key = "recv_key", shape = #tf.shape<>} : () -> tensor + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir new file mode 100644 index 00000000000..9f72820d15b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-include-tf2xla-fallback.mlir @@ -0,0 +1,50 @@ +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=false" -verify-diagnostics %s | FileCheck --check-prefix NO_FALLBACK %s +// RUN: tf-opt "-xla-legalize-tf=use-tf2xla-fallback=true device-type=XLA_CPU_JIT" -verify-diagnostics %s | FileCheck --check-prefix SUPPORTED_FALLBACK_DEVICE %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true" %s | FileCheck --check-prefix UNSPECIFIED_FALLBACK_DEVICE %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion use-tf2xla-fallback=true device-type=INVALID_DEVICE_TYPE" %s | FileCheck --check-prefix UNSUPPORTED_FALLBACK_DEVICE %s + +// We run this test four times: +// 1) Legalize without using TF2XLA fallback (ops cannot be legalized). +// 2) Use fallback with a device that supports all ops (ops can be legalized). +// 3) Use fallback with unspecified device (ops cannot be legalized). +// 4) Use fallback with specified but unsupported device (ops cannot be legalized). 
+// +// Note: For 3) and 4) we do not use `-verify-diagnostics` because these cases +// produce remarks that don't occur for 1) and 2) and there is no way to check +// the remarks only for 3) and 4) (except using two files). + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + +// CHECK-LABEL: non_max_suppression_v4 +func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = mhlo.constant dense<2> : tensor + // NO_FALLBACK: tf.NonMaxSuppressionV4 + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.NonMaxSuppressionV4 + // UNSPECIFIED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4 + // UNSUPPORTED_FALLBACK_DEVICE: tf.NonMaxSuppressionV4 + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0#0 : tensor<2xi32> +} + +// CHECK-LABEL: mirror_pad +func @mirror_pad(%arg0: tensor<2x3xcomplex>) -> tensor<4x7xcomplex> { + %0 = mhlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> + // NO_FALLBACK: tf.MirrorPad + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.MirrorPad + // UNSPECIFIED_FALLBACK_DEVICE: tf.MirrorPad + // UNSUPPORTED_FALLBACK_DEVICE: tf.MirrorPad + %1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex>, tensor<2x2xi32>) -> tensor<4x7xcomplex> + return %1 : tensor<4x7xcomplex> +} + +// CHECK-LABEL: atan2 +func @atan2(%arg0: tensor<4x1xf32>, %arg1: tensor<4x1x4xf32>) -> tensor<4x4x4xf32> { + // NO_FALLBACK: tf.Atan2 + // SUPPORTED_FALLBACK_DEVICE-NOT: tf.Atan2 + // UNSPECIFIED_FALLBACK_DEVICE: tf.Atan2 + // UNSUPPORTED_FALLBACK_DEVICE: tf.Atan2 + %0 = "tf.Atan2"(%arg0, %arg1) : (tensor<4x1xf32>, tensor<4x1x4xf32>) -> tensor<4x4x4xf32> + return %0: tensor<4x4x4xf32> +} + +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index ad4ef4b8f77..cd351447303 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -116,8 +116,7 @@ func @convert(%arg0: tensor<2xi32>) -> tensor<2xf32> { // CHECK-LABEL: func @constant func @constant(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: %[[SCALAR_ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[ONE:.*]] = "mhlo.broadcast_in_dim"(%[[SCALAR_ONE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<2xf32> + // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2xf32> // CHECK: %[[RESULT:.*]] = mhlo.divide %[[ONE]], %arg0 : tensor<2xf32> // CHECK: return %[[RESULT]] @@ -199,7 +198,6 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2 // CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor) func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor) -> tensor<3x3xf32> { -// CHECK: %[[CST:.*]] = mhlo.constant dense<3> : tensor<2xi32> // CHECK: %[[DEFAULT:.*]] = "mhlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor) -> tensor<3x3xf32> // CHECK: %[[RESULT:.*]] = "mhlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( { @@ -259,6 +257,14 @@ func @arg_min(%arg0: tensor<6xf64>) -> tensor { return %1 : tensor } +// CHECK-LABEL: non_max_suppression_v4 +func @non_max_suppression_v4(%arg0: 
tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor, %arg3: tensor) -> tensor<2xi32> { + %max_size = mhlo.constant dense<2> : tensor + // CHECK-NOT: tf.NonMaxSuppressionV4 + %0:2 = "tf.NonMaxSuppressionV4"(%arg0, %arg1, %max_size, %arg2, %arg3) {pad_to_max_output_size = true}: (tensor<3x4xf32>, tensor<3xf32>, tensor, tensor, tensor) -> (tensor<2xi32>, tensor) + return %0#0 : tensor<2xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 221fa19f77c..9b32fb97260 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FileCheck %s -// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=false" %s | FILECHECK_OPTS="" FileCheck %s +// RUN: tf-opt "-xla-legalize-tf=allow-partial-conversion legalize-chlo=true" -verify-diagnostics %s | FileCheck %s --check-prefix CHLO --dump-input-filter=all // This test runs twice: -// 1. Through FileCheck with chlo legalization disabled since verifying +// 1. Through FILECHECK_OPTS="" FileCheck with chlo legalization disabled since verifying // that the chlo ops emit produces more useful tests. // 2. With chlo legalization enabled, verifying diagnostics to pick up any // issues with the full lowering (can catch some broadcasting corner @@ -26,6 +26,28 @@ func @fusedBatchNorm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, return %0#0 : tensor<8x8x8x8xf32> } +// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same +// code), so only do a couple of basic checks. 
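The fused batch-norm tests that follow only pin down a few landmarks because the interesting numerics live in the `mhlo.batch_norm_inference` / `mhlo.batch_norm_training` ops themselves. For orientation, here is a minimal NumPy sketch of the inference formula those ops implement (illustrative only, with epsilon = 1e-3 and the feature axis last as in NHWC):

import numpy as np

def batch_norm_inference(x, scale, offset, mean, variance, epsilon=1e-3):
    # Normalizes per feature channel (the last axis for NHWC inputs),
    # matching mhlo.batch_norm_inference with feature_index = 3.
    return scale * (x - mean) / np.sqrt(variance + epsilon) + offset

x = np.random.rand(8, 8, 8, 8).astype(np.float32)
scale = np.ones(8, np.float32)
offset = np.zeros(8, np.float32)
mean = np.zeros(8, np.float32)
variance = np.full(8, 0.5, np.float32)
print(batch_norm_inference(x, scale, offset, mean, variance).shape)  # (8, 8, 8, 8)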
+
+// CHECK-LABEL: fusedBatchNormV2_noTraining
+func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
+  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  return %0#0 : tensor<8x8x8x8xf32>
+}
+
+// CHECK-LABEL: fusedBatchNormV2_training
+func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
+  // CHECK: %[[RESULT0:.*]] = "mhlo.batch_norm_training"({{.*}}, %arg1, %arg2) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>
+  %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
+  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 0 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8x8x8x8xf32>
+  // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+  // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
+  // CHECK: mhlo.constant
+  // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
+  return %0#0 : tensor<8x8x8x8xf32>
+}
+
 // CHECK-LABEL: fusedBatchNormV3_noTraining
 func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) {
   // CHECK: "mhlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32>
@@ -473,6 +495,142 @@ func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> {
   return %0: tensor<4x3xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// MatrixDiagPart
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @matrix_diag_part
+// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32>
+func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> {
+  // CHECK-DAG: %[[V0:.*]] = mhlo.constant dense<42> : tensor<i32>
+  // CHECK-DAG: %[[V1:.*]] = mhlo.constant dense<[-10, 11]> : tensor<2xi32>
+  // CHECK-DAG: %[[V2:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<1x22x128xi32>
+  // CHECK-DAG: %[[V3:.*]] = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<1x22x128xi32>
+  // CHECK-DAG: %[[V4:.*]] = mhlo.constant dense<0> : tensor<i32>
+  // CHECK-DAG:
%[[V5:.*]] = "mhlo.broadcast"(%[[V4]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V6:.*]] = mhlo.constant dense : tensor + // CHECK-DAG: %[[V7:.*]] = "mhlo.broadcast"(%[[V6]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V8:.*]] = mhlo.constant dense : tensor + // CHECK-DAG: %[[V9:.*]] = "mhlo.broadcast"(%[[V8]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V10:.*]] = mhlo.constant dense<11> : tensor + // CHECK-DAG: %[[V11:.*]] = "mhlo.broadcast"(%[[V10]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V12:.*]] = mhlo.constant dense<140> : tensor + // CHECK-DAG: %[[V13:.*]] = "mhlo.broadcast"(%[[V12]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V14:.*]] = mhlo.constant dense<128> : tensor + // CHECK-DAG: %[[V15:.*]] = "mhlo.broadcast"(%[[V14]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V16:.*]] = mhlo.constant dense<128> : tensor + // CHECK-DAG: %[[V17:.*]] = "mhlo.broadcast"(%[[V16]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V18:.*]] = mhlo.subtract %[[V11]], %[[V2]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V19:.*]] = "mhlo.negate"(%[[V18]]) : (tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V20:.*]] = mhlo.minimum %[[V18]], %[[V5]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V21:.*]] = mhlo.add %[[V13]], %[[V20]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V22:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V23:.*]] = mhlo.subtract %[[V15]], %[[V22]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V24:.*]] = mhlo.minimum %[[V21]], %[[V23]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V25:.*]] = chlo.broadcast_compare %[[V18]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V26:.*]] = mhlo.subtract %[[V17]], %[[V24]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V27:.*]] = "mhlo.select"(%[[V25]], %[[V26]], %[[V5]]) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + // CHECK-DAG: %[[V28:.*]] = mhlo.maximum %[[V18]], %[[V5]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V29:.*]] = mhlo.subtract %[[V28]], %[[V27]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V30:.*]] = mhlo.maximum %[[V19]], %[[V5]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V31:.*]] = mhlo.subtract %[[V30]], %[[V27]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V32:.*]] = mhlo.add %[[V3]], %[[V29]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V33:.*]] = mhlo.add %[[V3]], %[[V31]] : tensor<1x22x128xi32> + // CHECK-DAG: %[[V34:.*]] = chlo.broadcast_compare %[[V32]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V35:.*]] = chlo.broadcast_compare %[[V32]], %[[V15]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V36:.*]] = mhlo.and %[[V34]], %[[V35]] : tensor<1x22x128xi1> + // CHECK-DAG: %[[V37:.*]] = chlo.broadcast_compare %[[V33]], %[[V5]] {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V38:.*]] = 
chlo.broadcast_compare %[[V33]], %[[V13]] {comparison_direction = "LT"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-DAG: %[[V39:.*]] = mhlo.and %[[V37]], %[[V38]] : tensor<1x22x128xi1> + // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1> + // CHECK-DAG: %[[V41:.*]] = "mhlo.reshape"(%[[V40]]) : (tensor<1x22x128xi1>) -> tensor<22x128xi1> + // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) {dimension = 0 : i64} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) {dimension_numbers = {collapsed_slice_dims = dense<[1, 2]> : tensor<2xi64>, index_vector_dim = 0 : i64, offset_dims = dense<0> : tensor<1xi64>, start_index_map = dense<[1, 2]> : tensor<2xi64>}, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>} : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) {broadcast_sizes = dense<7> : tensor<1xi64>} : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) {broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<7x22x128xi32> + // CHECK: %[[V46:.*]] = "mhlo.select"(%[[V44]], %[[V43]], %[[V45]]) : (tensor<7x22x128xi1>, tensor<7x22x128xi32>, tensor<7x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK: return %[[V46]] : tensor<7x22x128xi32> + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + return %2: tensor<7x22x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_single_diagonal +func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<0> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x128xi32> + // CHECK: %[[result:.*]] = "mhlo.reshape"({{.*}}) : (tensor<7x1x128xi32>) -> tensor<7x128xi32> + // CHECK: return %[[result]] : tensor<7x128xi32> + return %2: tensor<7x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_align_ll +func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[false:.*]] = mhlo.constant dense : tensor + // CHECK: %[[b_false:.*]] = "mhlo.broadcast"(%[[false]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_false]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + return %2: tensor<7x22x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_align_lr +func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_RIGHT" + } : 
(tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[le:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "LE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[le]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + return %2: tensor<7x22x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_align_rl +func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[ge:.*]] = chlo.broadcast_compare %{{[0-9]*}}, %{{[0-9]*}} {comparison_direction = "GE"} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[ge]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + return %2: tensor<7x22x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_align_rr +func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_RIGHT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[true:.*]] = mhlo.constant dense : tensor + // CHECK: %[[b_true:.*]] = "mhlo.broadcast"(%[[true]]) {broadcast_sizes = dense<[1, 22, 128]> : tensor<3xi64>} : (tensor) -> tensor<1x22x128xi1> + // CHECK: %{{[0-9]*}} = "mhlo.select"(%[[b_true]], %{{[0-9]*}}, %{{[0-9]*}}) : (tensor<1x22x128xi1>, tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi32> + return %2: tensor<7x22x128xi32> +} + +// CHECK-LABEL: func @matrix_diag_part_align_7d +// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> +func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> { + %0 = mhlo.constant dense<-1.> : tensor // padding value + %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = f32, align = "LEFT_RIGHT" + } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor) -> tensor<3x5x7x9x11x4x10xf32> + return %2: tensor<3x5x7x9x11x4x10xf32> +} + //===----------------------------------------------------------------------===// // Einsum. //===----------------------------------------------------------------------===// @@ -958,7 +1116,7 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten // SparseMatMul where one operand needs to be transposed and the other one not. 
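The lowering exercised by the next two SparseMatMul tests just materializes the requested transpose (and, in the cast test, a bf16-to-f32 conversion) in front of a plain `mhlo.dot`; the `a_is_sparse`/`b_is_sparse` attributes are performance hints and do not change the math. A NumPy sketch of the transpose case (illustrative only):

import numpy as np

a = np.random.rand(3, 4).astype(np.float32)
b = np.random.rand(5, 4).astype(np.float32)  # operand with transpose_b = true

# tf.SparseMatMul(a, b, transpose_b=True) is numerically just a @ b.T,
# which is why the CHECK lines expect mhlo.transpose followed by mhlo.dot.
print((a @ b.T).shape)  # (3, 5), matching the tensor<3x5xf32> result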
// -// CHECK-LABEL: func @test_sparse_mat_mul_with_transpose +// CHECK-LABEL: @test_sparse_mat_mul_with_transpose // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> // CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> // CHECK-SAME: -> tensor<3x5xf32> @@ -968,7 +1126,6 @@ func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> ten // CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[TRANSPOSE]]) // CHECK-SAME: -> tensor<3x5xf32> // CHECK: return %[[RESULT]] -// CHECK: } func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> return %0: tensor<3x5xf32> @@ -976,7 +1133,7 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5 // SparseMatMul where one operand needs to be casted and the other one not. // -// CHECK-LABEL: func @test_sparse_mat_mul_with_cast +// CHECK-LABEL: @test_sparse_mat_mul_with_cast // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> // CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> // CHECK-SAME: -> tensor<3x5xf32> @@ -985,7 +1142,6 @@ func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5 // CHECK: %[[RESULT:.*]] = "mhlo.dot"(%[[ARG0]], %[[CAST]]) // CHECK-SAME: -> tensor<3x5xf32> // CHECK: return %[[RESULT]] -// CHECK: } func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> return %0: tensor<3x5xf32> @@ -1485,7 +1641,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_MAX:.*]] = "mhlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] // CHECK: %[[BCAST_MAX:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[SHIFTED_INP:.*]] = mhlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[EXP:.*]] = "mhlo.exponential"(%[[SHIFTED_INP]]) @@ -1500,7 +1656,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[RESULT:.*]] = mhlo.divide %[[EXP]], %[[BCAST_SUM]] // CHECK: return %[[RESULT]] @@ -1557,7 +1713,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %[[CASTED_SUM:.*]] = "mhlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "mhlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] - // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex> + // CHECK: 
%[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] // CHECK: %[[BCAST_SUM:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[RESULT:.*]] = mhlo.subtract {{.*}}, %[[BCAST_SUM]] // CHECK: return %[[RESULT]] @@ -1693,6 +1849,48 @@ func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } +// CHECK-LABEL: @acos +// CHLO-LABEL: @acos +func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: "chlo.acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> +// CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} +// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 +// CHLO: %[[VAL_4:.*]] = mhlo.constant dense<1.000000e+00> +// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] +// CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) +// CHLO: %[[VAL_8:.*]] = mhlo.constant dense<1.000000e+00> +// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 +// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] +// CHLO: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> +// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] +// CHLO: %[[VAL_12:.*]] = mhlo.constant dense<3.14159274> +// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) +// CHLO: return %[[VAL_13]] : tensor<2xf32> + %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// CHECK-LABEL: @acos_dynamic +// CHLO-LABEL: @acos_dynamic +func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "chlo.acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> +// CHLO: %[[VAL_1:.*]] = "mhlo.compare"({{.*}}) {comparison_direction = "NE"} +// CHLO: %[[VAL_5:.*]] = mhlo.multiply %arg0, %arg0 +// CHLO: %[[VAL_4:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} +// CHLO: %[[VAL_6:.*]] = mhlo.subtract %[[VAL_4]], %[[VAL_5]] +// CHLO: %[[VAL_7:.*]] = "mhlo.sqrt"(%[[VAL_6]]) +// CHLO: %[[VAL_8:.*]] = "chlo.constant_like"(%arg0) {value = 1.000000e+00 : f32} +// CHLO: %[[VAL_9:.*]] = mhlo.add %[[VAL_8]], %arg0 +// CHLO: %[[VAL_10:.*]] = mhlo.atan2 %[[VAL_7]], %[[VAL_9]] +// CHLO: %[[VAL_3:.*]] = "chlo.constant_like"(%arg0) {value = 2.000000e+00 : f32} +// CHLO: %[[VAL_11:.*]] = mhlo.multiply %[[VAL_3]], %[[VAL_10]] +// CHLO: %[[VAL_12:.*]] = "chlo.constant_like"(%arg0) {value = 3.14159274 : f32} +// CHLO: %[[VAL_13:.*]] = "mhlo.select"(%[[VAL_1]], %[[VAL_11]], %[[VAL_12]]) +// CHLO: return %[[VAL_13]] + %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // CHECK-LABEL: func @cast_dynamic_i2f func @cast_dynamic_i2f(%arg0: tensor) -> tensor { // CHECK: "mhlo.convert"(%arg0) : (tensor) -> tensor @@ -1900,7 +2098,7 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> - // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<1xindex> + // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<2xf32> // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<2xf32> // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> @@ -1922,7 +2120,7 @@ func 
@sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-DAG: [[SCALAR:%.+]] = mhlo.constant dense<5.000000e-01> : tensor // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> - // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor + // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] // CHECK-DAG: [[HALF:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor) -> tensor<*xf32> // CHECK-DAG: [[R1:%.+]] = mhlo.multiply %arg0, [[HALF]] : tensor<*xf32> // CHECK-DAG: [[R2:%.+]] = "mhlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> @@ -2126,11 +2324,8 @@ func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { // CHECK-LABEL: func @sign // CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: [[PRED:%.*]] = "mhlo.compare"([[ARG]], [[ARG]]) - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> // CHECK: [[SIGN:%.*]] = "mhlo.sign"([[ARG]]) - // CHECK: [[SELECT:%.*]] = "mhlo.select"([[PRED]], [[ZEROS]], [[SIGN]]) - // CHECK: return [[SELECT]] : tensor<1x2x3x4xf32> + // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) return %0 : tensor<1x2x3x4xf32> } @@ -3029,6 +3224,34 @@ func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor } +//===----------------------------------------------------------------------===// +// LegacyCall op legalizations. +//===----------------------------------------------------------------------===// + +func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { + return %arg0: tensor<10x2xf32> +} + +// CHECK-LABEL: testSimpleLegacyCallOp +func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { + // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> + %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> + // CHECK: return %[[RESULT]] + return %0: tensor<10x2xf32> +} + +func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { + return %arg0: tensor<10x2xf32> +} + +// CHECK-LABEL: testMultiInputLegacyCallOp +func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { + // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> + %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> + // CHECK: return %[[RESULT]] + return %0: tensor<10x2xf32> +} + //===----------------------------------------------------------------------===// // Conv op legalizations. 
//===----------------------------------------------------------------------===// @@ -3277,8 +3500,8 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { // tf.Size legalization //===----------------------------------------------------------------------===// -// CHECK-LABEL: @size_rank_one_i32 -func @size_rank_one_i32(%input: tensor) -> (tensor) { +// CHECK-LABEL: @size_scalar_i32 +func @size_scalar_i32(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor) -> tensor @@ -3286,8 +3509,8 @@ func @size_rank_one_i32(%input: tensor) -> (tensor) { return %size : tensor } -// CHECK-LABEL: @size_rank_one_i64 -func @size_rank_one_i64(%input: tensor) -> (tensor) { +// CHECK-LABEL: @size_scalar_i64 +func @size_scalar_i64(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor) -> tensor @@ -3295,19 +3518,40 @@ func @size_rank_one_i64(%input: tensor) -> (tensor) { return %size : tensor } +// CHECK-LABEL: @size_rank_one_i64 +// CHECK-SAME: (%[[INPUT:.*]]: tensor) +func @size_rank_one_i64(%input: tensor) -> (tensor) { + // CHECK: %[[INIT:.*]] = mhlo.constant dense<1> + // CHECK-SAME: tensor + + // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) + // CHECK-SAME: dimension = 0 + // CHECK-SAME: tensor + + // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor) -> tensor + // CHECK: %[[RESULT:.*]] = chlo.broadcast_multiply %[[INIT]], %[[CAST_DIM_0]] + + %size = "tf.Size"(%input) : (tensor) -> tensor + // CHECK: return %[[RESULT]] + return %size : tensor +} + // CHECK-LABEL: @size_ranked // CHECK-SAME: (%[[INPUT:.*]]: tensor<2x?x8xf32>) func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[CAST_DIM_0:.*]] = "mhlo.convert"(%[[DIM_0]]) : (tensor) -> tensor + // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[CAST_DIM_0]] // CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[CAST_DIM_1:.*]] = "mhlo.convert"(%[[DIM_1]]) : (tensor) -> tensor + // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[CAST_DIM_1]] // CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[CAST_DIM_2:.*]] = "mhlo.convert"(%[[DIM_2]]) : (tensor) -> tensor + // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[CAST_DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3846,36 +4090,167 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // tf.AvgPool legalization //===----------------------------------------------------------------------===// -// CHECK-LABEL: avgpool_valid_padding -// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x20x7xf16> -func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> { - // CHECK: [[CONV32:%.+]] = 
"mhlo.convert"(%arg0) : (tensor<2x12x20x7xf16>) -> tensor<2x12x20x7xf32> - // CHECK: [[INIT:%.+]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.+]] = "mhlo.reduce_window"([[CONV32]], [[INIT]]) ( { - // CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): - // CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] - // CHECK: "mhlo.return"([[ADD]]) - // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> - // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> - // CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> - // CHECK: return [[CONV16]] - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> +// CHECK-LABEL: @avgpool_valid_padding +// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> +// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: "mhlo.return"([[ADD]]) +// CHECK: }) +// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> +// CHECK-SAME: -> tensor<2x3x5x7xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: -> tensor<2x3x5x7xf32> +// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) +// CHECK-SAME: -> tensor<2x3x5x7xf16> +// CHECK: return [[CONV16]] +func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> return %0 : tensor<2x3x5x7xf16> } -// CHECK-LABEL: avgpool_same_padding -func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> { - // CHECK: tf.AvgPool - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> - return %0 : tensor<2x4x7x7xf32> +// CHECK-LABEL: @avgpool_3d_valid_padding +// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> +// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: "mhlo.return"([[ADD]]) +// CHECK: }) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> +// CHECK-SAME: -> tensor<2x4x3x5x7xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: 
broadcast_dimensions = dense<> +// CHECK-SAME: -> tensor<2x4x3x5x7xf32> +// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) +// CHECK-SAME: -> tensor<2x4x3x5x7xf16> +// CHECK: return [[CONV16]] +func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> + return %0 : tensor<2x4x3x5x7xf16> +} + +// CHECK-LABEL: @avgpool_nchw_format +// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> +// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: "mhlo.return"([[ADD]]) +// CHECK: }) +// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> +// CHECK-SAME: -> tensor<2x7x3x5xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: -> tensor<2x7x3x5xf32> +// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) +// CHECK-SAME: -> tensor<2x7x3x5xf16> +// CHECK: return [[CONV16]] +func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> + return %0 : tensor<2x7x3x5xf16> +} + +// CHECK-LABEL: @avgpool_3d_ncdhw_format +// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> +// CHECK: [[CONV32:%.+]] = "mhlo.convert"(%arg0) : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> +// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) ( { +// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): +// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] +// CHECK: "mhlo.return"([[ADD]]) +// CHECK: }) +// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> +// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> +// CHECK-SAME: -> tensor<2x7x4x3x5xf32> +// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor +// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] +// CHECK-SAME: broadcast_dimensions = dense<> +// CHECK-SAME: -> tensor<2x7x4x3x5xf32> +// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV_RESULT]]) +// CHECK-SAME: -> tensor<2x7x4x3x5xf16> +// CHECK: return [[CONV16]] +func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> + return %0 : tensor<2x7x4x3x5xf16> +} + +// CHECK-LABEL: @avgpool_same_padding( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( { +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: 
"mhlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> +// CHECK-SAME: -> tensor<2x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( { +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> +// CHECK-SAME: -> tensor<2x4x6x7xf32> +// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> +// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> +// CHECK: } +func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { + %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> + return %0 : tensor<2x4x6x7xf32> +} + +// CHECK-LABEL: @avgpool_3d_same_padding( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> +// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) ( { +// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): +// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: "mhlo.return"(%[[SUM1]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> +// CHECK-SAME: -> tensor<2x4x4x6x7xf32> +// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> +// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) ( { +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor +// CHECK: "mhlo.return"(%[[SUM2]]) : (tensor) -> () +// CHECK: }) +// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> +// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> +// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> +// CHECK-SAME: -> tensor<2x4x4x6x7xf32> +// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] +// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> +// CHECK: } +func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { + %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> + return %0 : tensor<2x4x4x6x7xf32> } //===----------------------------------------------------------------------===// // AvgPoolGrad op legalizations. 
//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @avgpool_grad_valid_padding( +// CHECK-LABEL: @avgpool_grad_valid_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor @@ -3907,7 +4282,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24 return %result : tensor<10x24x32x64xf32> } -// CHECK-LABEL: func @avgpool_3d_grad_valid_padding( +// CHECK-LABEL: @avgpool_3d_grad_valid_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor @@ -3936,7 +4311,7 @@ func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor< return %result : tensor<10x8x24x32x64xf32> } -// CHECK-LABEL: func @avgpool_grad_same_padding( +// CHECK-LABEL: @avgpool_grad_same_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> @@ -3975,7 +4350,7 @@ func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9x return %result : tensor<2x13x25x9xf32> } -// CHECK-LABEL: func @avgpool_3d_grad_same_padding( +// CHECK-LABEL: @avgpool_3d_grad_same_padding( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> @@ -4013,7 +4388,7 @@ func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x1 return %result : tensor<2x8x13x25x9xf32> } -// CHECK-LABEL: func @avgpool_grad_nchw_format( +// CHECK-LABEL: @avgpool_grad_nchw_format( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> @@ -4052,7 +4427,7 @@ func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf return %result : tensor<2x9x13x25xf32> } -// CHECK-LABEL: func @avgpool_3d_grad_ncdwh_format( +// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> @@ -4090,7 +4465,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8 return %result : tensor<2x9x8x13x25xf32> } -// CHECK-LABEL: func @avgpool_grad_bf16( +// CHECK-LABEL: @avgpool_grad_bf16( // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor @@ -4227,21 +4602,65 @@ func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { } // CHECK-LABEL: func @cumsum_exclusive +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: "tf.Cumsum" + // CHECK: [[AXIS:%.*]] = 
mhlo.constant dense<0> : tensor + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[X]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[CONVERT_REDUCE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> return %1 : tensor<4xf32> } // CHECK-LABEL: func @cumsum_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: "tf.Cumsum" + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[REDUCE]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> return %1 : tensor<4xf32> } +// CHECK-LABEL: func @cumsum_exclusive_reverse +// CHECK-SAME: [[X:%.*]]: tensor<4xf32> +func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor + // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[CONVERT_X:%.*]] = "mhlo.convert"([[REVERSE1]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) ( { + // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): + // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor + // CHECK: "mhlo.return"([[SUM]]) : (tensor) -> () + // CHECK: }) {padding = dense<{{\[\[}}3, 
0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) {edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: [[CONVERT_REDUCE:%.*]] = "mhlo.convert"([[PAD]]) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: return [[REVERSE_BACK]] + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + // CHECK-LABEL: func @cumsum_dynamic func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "tf.Cumsum" @@ -4249,6 +4668,10 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor return %0 : tensor } +//===----------------------------------------------------------------------===// +// Qr op legalization +//===----------------------------------------------------------------------===// + // CHECK: func @qr([[VAL_0:%.*]]: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) { // The tf.Qr lowering is a full algorithm that is not effective to verify with diff --git a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir index 57959568287..1032bb723c5 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/case.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/case.mlir @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s +// RUN: tf-mlir-translate -split-input-file -mlir-hlo-to-hlo-text %s | FILECHECK_OPTS="" FileCheck %s func @main() -> tensor { %cst = constant {name = "constant"} dense<1> : tensor diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir index 9929bd85b43..316eda4c4aa 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir +++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir @@ -1087,3 +1087,15 @@ func @main(%arg: tensor<3x4xf32>, %token: !mhlo.token) -> !mhlo.token { } // CHECK-NOT: frontend_attributes + +// ----- + +// Checks exporting rng-bit-generator. 
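For the rng-bit-generator checks that follow (and the matching import test further down), the integer attribute is a thin encoding of XLA's random-algorithm choice. Below is a minimal sketch of that mapping; only the pairing 2 <-> rng_philox is actually exercised by these tests, and the entries for 0 and 1 are assumptions based on XLA's RandomAlgorithm enum, not something shown in this diff.

// Sketch only: enum values 0 and 1 are assumed; value 2 matches the tests.
enum class RngAlgorithm : int { kDefault = 0, kThreeFry = 1, kPhilox = 2 };

const char* RngAlgorithmName(RngAlgorithm a) {
  switch (a) {
    case RngAlgorithm::kDefault:  return "rng_default";
    case RngAlgorithm::kThreeFry: return "rng_three_fry";
    case RngAlgorithm::kPhilox:   return "rng_philox";  // rng_algorithm = 2
  }
  return "unknown";
}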
+ +// CHECK: HloModule +func @main(%arg: tensor<3xui64>) -> tuple, tensor<2x2xui32>> { +// CHECK: %[[ARG0:.*]] = u64[3] parameter(0) +// CHECK: ROOT %[[RESULT:.*]] = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %[[ARG0]]), algorithm=rng_philox + %0 = "mhlo.rng_bit_generator"(%arg) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> + return %0 : tuple, tensor<2x2xui32>> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fusion.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fusion.hlotxt new file mode 100644 index 00000000000..dc2ce6d58f8 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/fusion.hlotxt @@ -0,0 +1,35 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.17 + +// CHECK: func @main(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> tensor { +// CHECK: %0 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) ( { +// CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): +// CHECK: }) {fusion_kind = "kLoop"} : (tensor, tensor) -> tensor +// CHECK: %1 = "mhlo.fusion"(%[[ARG0:.*]], %[[ARG1:.*]]) ( { +// CHECK: ^bb0(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor): +// CHECK: }) {fusion_kind = "kLoop"} : (tensor, tensor) -> tuple, tensor> +// CHECK: } + +%region_0.3 (Arg_0.4: f32[], Arg_1.5: f32[]) -> f32[] { + %Arg_0.4 = f32[] parameter(0) + %Arg_1.5 = f32[] parameter(1) + ROOT %add.6 = f32[] add(f32[] %Arg_0.4, f32[] %Arg_1.5) +} + +%region_1.8 (Arg_0.9: f32[], Arg_1.10: f32[]) -> (f32[], f32[]) { + %Arg_0.9 = f32[] parameter(0) + %Arg_1.10 = f32[] parameter(1) + %add.11 = f32[] add(f32[] %Arg_0.9, f32[] %Arg_1.10) + %subtract.12 = f32[] subtract(f32[] %Arg_0.9, f32[] %Arg_1.10) + ROOT %tuple.13 = (f32[], f32[]) tuple(f32[] %add.11, f32[] %subtract.12) +} + +ENTRY %main.17 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + %fusion.7 = f32[] fusion(f32[] %Arg_0.1, f32[] %Arg_1.2), kind=kLoop, calls=%region_0.3 + %fusion.14 = (f32[], f32[]) fusion(f32[] %Arg_0.1, f32[] %Arg_1.2), kind=kLoop, calls=%region_1.8 + %get-tuple-element.15 = f32[] get-tuple-element((f32[], f32[]) %fusion.14), index=0 + ROOT %get-tuple-element.16 = f32[] get-tuple-element((f32[], f32[]) %fusion.14), index=1 +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fusion.mlir b/tensorflow/compiler/mlir/xla/tests/translate/fusion.mlir new file mode 100644 index 00000000000..7da9b7c5f7b --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/fusion.mlir @@ -0,0 +1,27 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK: %[[REGION0:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> f32[] +// CHECK: %[[REGION1:.*]] ({{.*}}: f32[], {{.*}}: f32[]) -> (f32[], f32[]) +// +// CHECK: ENTRY +// CHECK: %[[PARAM0:.*]] = f32[] parameter(0) +// CHECK: %[[PARAM1:.*]] = f32[] parameter(1) +// CHECK: %[[FUSION0:.*]] = f32[] fusion(f32[] %[[PARAM0]], f32[] %[[PARAM1]]), kind=kLoop, calls=%[[REGION0]] +// CHECK: %[[FUSION1:.*]] = (f32[], f32[]) fusion(f32[] %[[PARAM0]], f32[] %[[PARAM1]]), kind=kLoop, calls=%[[REGION1]] +// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION1]]), index=0 +// CHECK: f32[] get-tuple-element((f32[], f32[]) %[[FUSION1]]), index=1 +// CHECK: } +func @main(%arg0: tensor, %arg1: tensor) { + %result = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %result = "mhlo.add"(%arg2, %arg3): (tensor, tensor) -> tensor + "mhlo.return"(%result) : (tensor) -> () + }) { fusion_kind = "kLoop" } : (tensor, tensor) -> tensor + 
%result0, %result1 = "mhlo.fusion"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %elem0 = "mhlo.add"(%arg2, %arg3): (tensor, tensor) -> tensor + %elem1 = "mhlo.subtract"(%arg2, %arg3): (tensor, tensor) -> tensor + "mhlo.return"(%elem0, %elem1) : (tensor, tensor) -> () + }) { fusion_kind="kLoop" } : (tensor, tensor) -> (tensor, tensor) + return +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt index 2b7d44f4522..4d4e0213da8 100644 --- a/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/translate/import.hlotxt @@ -1,4 +1,4 @@ -// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -DPRIVATE="attributes {sym_visibility = \"private\"}" +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FILECHECK_OPTS="" FileCheck %s -DPRIVATE="attributes {sym_visibility = \"private\"}" HloModule main @@ -1005,3 +1005,12 @@ add { // CHECK: "mhlo.not"(%[[ARG0]]) {name = "{{.*}}"} : (tensor<4xui16>) -> tensor<4xui16> ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1) } + +// CHECK-LABEL: func @rngbitgen +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>) +%rngbitgen (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) { + %Arg_0.1 = u64[3] parameter(0) + // CHECK: "mhlo.rng_bit_generator"(%[[ARG0]]) {rng_algorithm = 2 : i32} : (tensor<3xui64>) -> tuple, tensor<2x2xui32>> + ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/non_isolated_computation.mlir b/tensorflow/compiler/mlir/xla/tests/translate/non_isolated_computation.mlir new file mode 100644 index 00000000000..94f53ebbfcb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/non_isolated_computation.mlir @@ -0,0 +1,16 @@ +// RUN: not tf-mlir-translate -mlir-hlo-to-hlo-text %s 2>&1 | FileCheck %s + +func @main(%arg0: tensor) -> tensor { + %c0 = mhlo.constant dense<1> : tensor + %0 = "mhlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: requires all operands to be defined in the parent region for export + %1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %2 = mhlo.add %arg1, %arg1 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 23f11cef4d9..5fe933ee635 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -71,9 +71,14 @@ class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; LegalizeTF(const LegalizeTF &) {} - explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo) { + explicit LegalizeTF(bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { allow_partial_conversion_ = allow_partial_conversion; legalize_chlo_ = legalize_chlo; + use_tf2xla_fallback_ = tf2xla_fallback_device_type.hasValue(); + if (tf2xla_fallback_device_type.hasValue()) { + device_type_ = tf2xla_fallback_device_type.getValue().str(); + } } /// Performs the lowering to XLA dialect. 
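The constructor change above (and the new pass options in the next hunk) hinge on one idea: the optional TF2XLA fallback device type doubles as the enable flag. A minimal standalone sketch of that plumbing follows; LegalizeOptions and MakeOptions are hypothetical stand-ins, and std::optional stands in for llvm::Optional.

#include <optional>
#include <string>

// Hypothetical stand-in for the pass options: an absent device type means the
// TF2XLA fallback patterns are not used; a present one both enables the
// fallback and records which device's kernels drive it.
struct LegalizeOptions {
  bool use_tf2xla_fallback = false;
  std::string device_type = "INVALID_DEVICE_TYPE";
};

LegalizeOptions MakeOptions(std::optional<std::string> fallback_device_type) {
  LegalizeOptions opts;
  opts.use_tf2xla_fallback = fallback_device_type.has_value();
  if (fallback_device_type) opts.device_type = *fallback_device_type;
  return opts;
}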
@@ -89,6 +94,17 @@ class LegalizeTF : public PassWrapper { llvm::cl::desc( "Also legalizes intermediate chlo ops to hlo (default true)"), llvm::cl::init(true)}; + Option use_tf2xla_fallback_{ + *this, "use-tf2xla-fallback", + llvm::cl::desc( + "Also use TF2XLA fallback for legalization (default false)"), + llvm::cl::init(false)}; + Option device_type_{ + *this, "device-type", + llvm::cl::desc( + "The device type used by TF2XLA fallback. Must be specified if " + "use-tf2xla-fallback is true, otherwise not used."), + llvm::cl::init("INVALID_DEVICE_TYPE")}; }; /// Returns if the given TF data format string is the default format. @@ -365,7 +381,7 @@ static Value UpdateSliceInMinorDims(Location loc, Value v, Value update, ArrayRef minor_starts, OpBuilder *builder) { llvm::SmallVector dus_starts(minor_starts.size()); - for (int64_t i = 0; i < minor_starts.size(); ++i) { + for (uint64_t i = 0; i < minor_starts.size(); ++i) { dus_starts[i] = GetScalarConstOfType(builder->getIntegerType(32), loc, minor_starts[i], builder); } @@ -808,7 +824,7 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( values.reserve(shaped_type.getNumElements() / shape[1]); for (auto it : llvm::enumerate(int_attr.getIntValues())) { - if (it.index() % shape[1] == column) { + if (static_cast(it.index() % shape[1]) == column) { values.push_back(it.value().getSExtValue()); } } @@ -896,7 +912,7 @@ static DenseElementsAttr GetEpsilonValue(Type ty) { auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon)); return DenseElementsAttr::get(scalar_ty, value); } else if (element_ty.isBF16()) { - uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value; + uint16_t raw_epsilon = Eigen::NumTraits::epsilon().value; auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon)); return DenseElementsAttr::get(scalar_ty, value); } else if (element_ty.isF32()) { @@ -1387,6 +1403,269 @@ class ConvertDiagPartOp : public OpRewritePattern { } }; +// Converts TensorFlow MatrixDiagPartOp to HLO ops. +class ConvertMatrixDiagPartV3Op + : public OpRewritePattern { + using Shape = llvm::SmallVector; + + // Parse the "k" parameter. MatrixDiagPartV3 allows to specify the diagonal(s) + // with k. This can be either a single value (for a single diagonal) or a + // tuple of two values (starting and ending diagonal, for a band). + LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const { + DenseIntElementsAttr kattr; + if (!matchPattern(op.k(), m_Constant(&kattr))) { + return failure(); + } + DenseIntElementsAttr::iterator it = kattr.begin(); + (*k)[0] = (*it).getSExtValue(); + it++; + if (it == kattr.end()) { + // Handle input like e.g. "k = 5", in which case we extract a single + // diagonal. + (*k)[1] = (*k)[0]; + } else { + // Handle input like e.g. "k = [-1, 1]", in which case we extract a + // band (multiple diagonals). + (*k)[1] = (*it).getSExtValue(); + } + return success(); + } + + // Utility method for broadcasting integer constants to a given shape. 
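The ExtractK helper above encodes the MatrixDiagPartV3 convention for the "k" input: one element selects a single diagonal, two elements select a band. A standalone sketch of that convention, with ParseK as a hypothetical helper over a plain vector rather than a DenseIntElementsAttr:

#include <array>
#include <cstdint>
#include <vector>

// Mirrors ExtractK: "k = 5" selects the single diagonal 5; "k = [-1, 1]"
// selects the band from subdiagonal -1 up to superdiagonal 1 (inclusive).
std::array<int64_t, 2> ParseK(const std::vector<int64_t>& k_attr) {
  std::array<int64_t, 2> k = {0, 0};
  if (k_attr.empty()) return k;  // the op always supplies at least one value
  k[0] = k_attr[0];
  k[1] = (k_attr.size() > 1) ? k_attr[1] : k_attr[0];
  return k;
}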
+ BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, + int int_size, PatternRewriter &rewriter) const { + return rewriter.create( + loc, RankedTensorType::get(shape, rewriter.getIntegerType(int_size)), + GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, + &rewriter), + GetI64ElementsAttr(shape, &rewriter)); + } + + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + ShapedType input_type = op.input().getType().dyn_cast(); + auto element_type = input_type.getElementType(); + + // Align is a string specifying how superdiagonals and subdiagonals should + // be aligned/padded for diagonals that are shorter than max_diag_len. The + // format is "{super}_{sub}", with {super} the superdiagonal alignment and + // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the + // left, "RIGHT" means rows will be padded ot the right. The default is + // "RIGHT_LEFT". + StringRef align = op.getAttrOfType("align").getValue(); + enum Alignment { kLeft, kRight }; + + // default is RIGHT_LEFT + Alignment superdiagonal_align = kRight; + Alignment subdiagonal_align = kLeft; + + if (align == "RIGHT_LEFT") { + superdiagonal_align = kRight; + subdiagonal_align = kLeft; + } else if (align == "RIGHT_RIGHT") { + superdiagonal_align = kRight; + subdiagonal_align = kRight; + } else if (align == "LEFT_RIGHT") { + superdiagonal_align = kLeft; + subdiagonal_align = kRight; + } else if (align == "LEFT_LEFT") { + superdiagonal_align = kLeft; + subdiagonal_align = kLeft; + } else { + return failure(); // unsupported alignment + } + + // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and + // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L]. + if (!input_type || !input_type.hasStaticShape()) return failure(); + int64_t num_dims = input_type.getRank(); + if (num_dims < 2) return failure(); + int64_t rows = input_type.getDimSize(num_dims - 2); // rows + int64_t cols = input_type.getDimSize(num_dims - 1); // cols + + // We extract the diagonals from k[0] up to and including k[1]. + // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract + // the main diagonal). It's negative for subdiagonals (under and to the left + // of the main diagonal) and positive for superdiagonals (above and to the + // right of the main diagonal). + int64_t k[2]; + if (failed(ExtractK(op, &k))) return failure(); + int num_diags = k[1] - k[0] + 1; + + // Shifting diagonals away from the main diagonal might shorten them. This + // is the longest diagonal we will see. We make this the last dimension of + // the output shape. + int64_t max_diag_len = + std::min(rows + std::min(k[1], static_cast(0)), + cols + std::min(-k[0], static_cast(0))); + + // The first dimension is the index vector dimension we'll use for gather. + // It's 1 here, but will be 2 once we glue x and y together. + Shape indices_shape({1, num_diags, max_diag_len}); + + RankedTensorType iota_type = + RankedTensorType::get(indices_shape, rewriter.getIntegerType(32)); + Value iotaM = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); + Value iotaN = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); + + // Boradcasted constants, of the same shape as iotaM and iotaN. 
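A worked, self-contained version of the diagonal bookkeeping computed just above (num_diags, max_diag_len, and the per-diagonal length diag_len_d). The concrete numbers (a 3x4 matrix with k = [-1, 1]) are my own example; the MLIR pattern evaluates the same formulas on broadcast i32 tensors.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t rows = 3, cols = 4;
  const int64_t k[2] = {-1, 1};  // band from subdiagonal -1 to superdiagonal 1
  const int64_t num_diags = k[1] - k[0] + 1;
  // Longest diagonal in the band; becomes the last output dimension.
  const int64_t max_diag_len =
      std::min(rows + std::min(k[1], int64_t{0}),
               cols + std::min(-k[0], int64_t{0}));
  std::printf("num_diags=%lld max_diag_len=%lld\n",
              (long long)num_diags, (long long)max_diag_len);  // 3 and 3
  // Per-diagonal length, indexed as in the lowering: d = k[1] - m, so m = 0 is
  // the topmost superdiagonal and d decreases from there.
  for (int64_t m = 0; m < num_diags; ++m) {
    const int64_t d = k[1] - m;
    const int64_t diag_len_d = std::min(rows + std::min(d, int64_t{0}),
                                        cols - std::max(d, int64_t{0}));
    std::printf("m=%lld d=%lld diag_len=%lld\n",
                (long long)m, (long long)d, (long long)diag_len_d);
  }
  return 0;  // lengths are 3, 3, 2 for d = 1, 0, -1
}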
+ Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); + Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter); + Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter); + Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter); + Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter); + Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter); + Value b_max_diag_len = + BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter); + + // d = k[1] - m + // (A.k.a. the number of the diagonal, depending on m. Note that we + // subtract m here. This means we start with the superdiagonals and + // move downwards towards the subdiagonals. So the start indices will + // be decreasing.) + Value d = rewriter.create(loc, b_k1, iotaM); + Value neg_d = rewriter.create(loc, d); + + // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) + // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) + Value diag_len_d = rewriter.create( + loc, + rewriter.create(loc, b_rows, + rewriter.create(loc, d, b_zero)), + rewriter.create(loc, b_cols, + rewriter.create(loc, d, b_zero))); + + // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. + Value cmp; + if (subdiagonal_align == kRight && superdiagonal_align == kRight) { + cmp = b_true; + } else if (superdiagonal_align == kRight) { + // offset = d>=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else if (subdiagonal_align == kRight) { + // offset = d<=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else { + // offset = 0 + cmp = b_false; + } + + // This offset shifts the diagonals to the "left" or "right", depending + // on alignment. + Value offset = rewriter.create( + loc, b_zero.getType(), cmp, + rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); + + // x = max(d, 0) - offset + // y = max(-d, 0) - offset + Value x = rewriter.create( + loc, rewriter.create(loc, d, b_zero), offset); + Value y = rewriter.create( + loc, rewriter.create(loc, neg_d, b_zero), offset); + + Value n_plus_x = rewriter.create(loc, iotaN, x); + Value n_plus_y = rewriter.create(loc, iotaN, y); + + // GatherOp is happy about letting us index out of bounds values, but those + // values will be undefined. So we mask them later. Set up the boolean + // expression that tells us which entries, in the output shape, are out of + // bounds and thus become the padding_value. + Value x_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_x, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); + Value y_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_y, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); + Value in_bounds = rewriter.create( + loc, + RankedTensorType::get(Shape({num_diags, max_diag_len}), + rewriter.getIntegerType(1)), + rewriter.create(loc, x_in_bounds, y_in_bounds)); + + // Now combine x and y into the index data structure needed for gather. + Shape concat_shape({2, num_diags, max_diag_len}); + Value start_indices = rewriter.create( + loc, RankedTensorType::get(concat_shape, rewriter.getIntegerType(32)), + mlir::ValueRange({n_plus_y, n_plus_x}), + mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0)); + + // Shape of the final output. (Except for dimension folding in the + // single diagonal case.) 
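A scalar sketch of the per-entry index arithmetic built above (d, the alignment offset, x, y, and the in-bounds mask), shown here for the default RIGHT_LEFT alignment. DiagPartIndex is a hypothetical helper; the pattern evaluates the same formulas on broadcast tensors and feeds the (row, col) pairs to mhlo.gather, masking out-of-bounds entries with padding_value.

#include <algorithm>
#include <cstdint>

struct DiagIndex {
  int64_t row, col;  // gather coordinates into the last two input dimensions
  bool in_bounds;    // false => the output entry becomes padding_value
};

// m indexes the diagonal (0 = k[1], the topmost one); n is the position in it.
DiagIndex DiagPartIndex(int64_t rows, int64_t cols, const int64_t k[2],
                        int64_t max_diag_len, int64_t m, int64_t n) {
  const int64_t d = k[1] - m;
  const int64_t diag_len_d = std::min(rows + std::min(d, int64_t{0}),
                                      cols - std::max(d, int64_t{0}));
  // RIGHT_LEFT default: shorter superdiagonals (d >= 0) get a nonzero offset
  // so their entries are shifted within the output row; subdiagonals keep 0.
  const int64_t offset = (d >= 0) ? max_diag_len - diag_len_d : 0;
  const int64_t x = std::max(d, int64_t{0}) - offset;   // column base
  const int64_t y = std::max(-d, int64_t{0}) - offset;  // row base
  const int64_t col = x + n;
  const int64_t row = y + n;
  const bool in_bounds = col >= 0 && col < cols && row >= 0 && row < rows;
  return {row, col, in_bounds};
}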
+ Shape output_shape; + for (int i = 0; i < num_dims - 2; i++) { + output_shape.push_back(input_type.getDimSize(i)); + } + output_shape.push_back(num_diags); + output_shape.push_back(max_diag_len); + auto output_type = RankedTensorType::get(output_shape, element_type); + + // A slice is the shape of what GatherOp copies per lookup. So the last + // two dimensions (M, N in the matrix-diag-part docs) are where we go + // through entry by entry. + ArrayRef input_shape = input_type.getShape(); + Shape slice_sizes(input_shape.begin(), input_shape.end()); + int slice_dimensions = slice_sizes.size(); + slice_sizes[slice_dimensions - 2] = 1; + slice_sizes[slice_dimensions - 1] = 1; + + // Dimensions of the input we won't see in the output (M and N). + SmallVector collapsed_dims( + {slice_dimensions - 2, slice_dimensions - 1}); + + // Which dimensions (in the input) the two offset "columns" map to. + SmallVector start_index_map({num_dims - 2, num_dims - 1}); + + // Gather the diagonal entries. + // TODO(kramm): For a single diagonal, this might be slower than the + // mask + sum approach. Special-case num_diags==1? + auto dims_attr = GatherDimensionNumbers::get( + /*offset_dims=*/GetI64ElementsAttrForSeq(0, num_dims - 2, &rewriter), + /*collapsed_slice_dims=*/GetI64ElementsAttr(collapsed_dims, &rewriter), + /*start_index_map=*/GetI64ElementsAttr(start_index_map, &rewriter), + /*index_vector_dim=*/rewriter.getI64IntegerAttr(0), + rewriter.getContext()); + Value gather = rewriter.create( + loc, output_type, op.input(), start_indices, dims_attr, + GetI64ElementsAttr(slice_sizes, &rewriter)); + + // We now need to broadcast the "in_bounds" boolean expression, as well as + // the padding value, to do the final select. + Shape broadcast_bounds; + for (int i = 0; i < output_shape.size() - 2; i++) { + broadcast_bounds.push_back(output_shape[i]); + } + Value b_in_bounds = rewriter.create( + loc, RankedTensorType::get(output_shape, rewriter.getIntegerType(1)), + in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); + Value b_padding = rewriter.create( + loc, output_type, op.padding_value(), + GetI64ElementsAttr(output_shape, &rewriter)); + + // Replace all out-of-bounds values in the result with padding_value. + Value result = rewriter.create(loc, output_type, b_in_bounds, + gather, b_padding); + + if (num_diags == 1) { + // matrix_diag_part folds away the 1-sized band dimension if we only + // extract a single diagonal. + result = rewriter.create(loc, op.getType(), result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + // Converts TensorFlow EinsumOp to either HLO EinsumOp or UnaryEinsumOp // depending on arity of the op. class ConvertEinsumOp : public OpRewritePattern { @@ -1531,23 +1810,23 @@ using ConvertFusedBatchNormGradV3Op = // Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or // HLO BatchNormInferenceOp, depending on the value of the 'is_training' // parameter. 
-class ConvertFusedBatchNormV3Op - : public OpRewritePattern { +template +class ConvertFusedBatchNormBase : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::FusedBatchNormV3Op op, + LogicalResult matchAndRewrite(FusedBatchNormOpT op, PatternRewriter &rewriter) const override { auto feature_dim = getFeatureDimensionAttr(rewriter, op.data_formatAttr(), op.x()); - auto input_type_tensor = op.x().getType().cast(); + auto input_type_tensor = op.x().getType().template cast(); auto input_element_type = input_type_tensor.getElementType(); - auto scale_type_tensor = op.scale().getType().cast(); + auto scale_type_tensor = op.scale().getType().template cast(); auto scale_element_type = scale_type_tensor.getElementType(); - auto mean_type_tensor = op.mean().getType().cast(); + auto mean_type_tensor = op.mean().getType().template cast(); auto mean_element_type = mean_type_tensor.getElementType(); // In the training case, dimensions of input tensors must be static. if (op.is_training() && (!input_type_tensor.hasStaticShape() || @@ -1561,7 +1840,7 @@ class ConvertFusedBatchNormV3Op Value bn_train_input = rewriter.create(op.getLoc(), op.x(), scale_element_type); TensorType bn_train_input_type_tensor = - bn_train_input.getType().cast(); + bn_train_input.getType().template cast(); if (op.is_training()) { // Training case. @@ -1643,17 +1922,25 @@ class ConvertFusedBatchNormV3Op /*broadcast_dimensions=*/DenseIntElementsAttr()); } - // TF FusedBatchNormV3 op expects 5 outputs. Outputs 3 and 4 are - // currently marked as "reserved spaces 1 and 2". They are used to - // pass the per-batch mean and variance to the gradiant. Here we - // maintain the same behavior by setting them to the mean and variance - // calculated by BatchNormTraining. Output 5 is unused; it doesn't - // matter what we pass there. - rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, - /*batch_variance=*/corrected_variance, - /*reserve_space_1=*/reserve_space_1, - /*reserve_space_2=*/batch_variance, - /*reserve_space_3=*/op.x()}); + if (std::is_same::value) { + // FusedBatchNormV2 expects 4 outputs. + // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2". + // They are used to pass the per-batch mean and variance to the + // gradiant. Here we maintain the same behavior by setting them to the + // mean and variance calculated by BatchNormTraining. + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance}); + } else { // TF::FusedBatchNormV3Op + // FusedBatchNormV3 expects a 5th output, but the output is unused; it + // doesn't matter what we pass there. + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance, + /*reserve_space_3=*/op.x()}); + } } else { // Inference case. auto bn_train_op = rewriter.create( op.getLoc(), @@ -1670,31 +1957,45 @@ class ConvertFusedBatchNormV3Op // not used for inference. It doesn't matter what values we provide for // the last 5 results as long as they are of the same type. Forward // input mean and variance to output mean, variance, reserved_space_1 and - // reserver_space_2. Create a constant tensor to forward to last - // reserve_space_3 output. 
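The V2/V3 split above comes down to a compile-time check of the op type: the replaceOp call for V2 passes five results, while V3 carries one extra reserve_space_3 result that the lowering fills with a dummy value. A minimal sketch of that dispatch shape, with toy tag types standing in for the TF ops:

#include <type_traits>
#include <vector>

struct FusedBatchNormV2Tag {};  // toy stand-ins for TF::FusedBatchNormV2Op
struct FusedBatchNormV3Tag {};  // and TF::FusedBatchNormV3Op

template <typename OpT>
std::vector<const char*> ResultNames() {
  std::vector<const char*> names = {"y", "batch_mean", "batch_variance",
                                    "reserve_space_1", "reserve_space_2"};
  // Only the V3 variant has the extra result.
  if (!std::is_same<OpT, FusedBatchNormV2Tag>::value)
    names.push_back("reserve_space_3");
  return names;
}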
- auto reserve_space_3_type = op.getResult(5).getType().cast(); - int num_elements = reserve_space_3_type.hasStaticShape() - ? reserve_space_3_type.getNumElements() - : 0; - auto const_attr_type = RankedTensorType::get( - {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - - Value dummy_const = rewriter.create( - op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); - if (const_attr_type != reserve_space_3_type) - dummy_const = rewriter.create( - op.getLoc(), reserve_space_3_type, dummy_const); - rewriter.replaceOp(op, {/*y=*/y_out, - /*batch_mean=*/op.mean(), - /*batch_variance=*/op.variance(), - /*reserve_space_1=*/op.mean(), - /*reserve_space_2=*/op.variance(), - /*reserve_space_3=*/dummy_const}); + // reserved_space_2. + if (std::is_same::value) { + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.mean(), + /*batch_variance=*/op.variance(), + /*reserve_space_1=*/op.mean(), + /*reserve_space_2=*/op.variance()}); + } else { + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + op.getResult(5).getType().template cast(); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = RankedTensorType::get( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.mean(), + /*batch_variance=*/op.variance(), + /*reserve_space_1=*/op.mean(), + /*reserve_space_2=*/op.variance(), + /*reserve_space_3=*/dummy_const}); + } } return success(); } }; +using ConvertFusedBatchNormV2Op = + ConvertFusedBatchNormBase; +using ConvertFusedBatchNormV3Op = + ConvertFusedBatchNormBase; + using PaddingArray = std::vector>; @@ -1748,37 +2049,102 @@ static DenseIntElementsAttr GetReduceWindowPaddingAsAttr( flatten_paddings); } +// Helper function for dividing each entry of `pooled` by the count of its +// corresponding window, i.e., the number of non-padding entries of the window +// which an `AvgPool` operation performed on an `input_shape`-tensor would map +// to this entry, depending on `ksize` and `strides`. This function is used for +// `AvgPool` and `AvgPoolGrad` legalizations. +// `zero` is passed as a parameter because it can be reused from caller level. +// `pooled` must have `RankedTensorType`. +template +Operation *AvgPoolDivideByCount( + Value pooled, const SmallVector &input_shape, + const SmallVector &ksize, + const SmallVector &strides, OpTy op, Value zero, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + RankedTensorType pooled_type = + pooled.getType().template cast(); + Type element_type = pooled_type.getElementType(); + Operation *result = nullptr; + RankedTensorType orig_input_type = + RankedTensorType::get(input_shape, element_type); + + if (op.padding() == "VALID") { + // All window counts are equal here because we don't have padding + // (each entry of `pooled` corresponds to a window that consists of + // original input entries only). + int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1, + std::multiplies()); + // Divide `pooled` by window counts. 
+ Value divisor = + GetScalarConstOfType(element_type, loc, window_count, &rewriter); + auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); + result = rewriter.create( + loc, pooled_type, pooled, divisor, scalar_broadcast_dims); + } else { + assert(op.padding() == "SAME"); + // For SAME padding, only original entries that contributed to a window + // are counted for the average of this window, not padded entries. + + // Build all-ones tensor of same shape as the original input. + ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); + auto all_ones_tensor = rewriter.create(loc, splat); + + // Get padding for the input. + DenseIntElementsAttr input_padding_attr = + GetReduceWindowPaddingAsAttr( + input_shape, op.ksize(), op.strides(), op.padding(), &rewriter); + + // Count the 1's in each window, using the same padding as for the input, + // which gives us the window counts by which `pooled` needs to be divided. + auto divisor = rewriter.create( + loc, pooled_type, + /*operand=*/all_ones_tensor, + /*init_value=*/zero, + /*window_dimensions=*/GetI64ElementsAttr(op.ksize()), + /*window_strides=*/GetI64ElementsAttr(op.strides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*padding=*/input_padding_attr); + BuildReduceBody(element_type, &divisor.body(), &rewriter); + + // Divide `pooled` by window counts. + result = rewriter.create(loc, pooled_type, pooled, divisor); + } + return result; +} + +Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.value(); } +Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.input(); } + // Converts AvgPool op to HLO ReduceWindow op by setting appropriate window // dimensions with add as the reduction function. The reduction result is // then divided by the number of elements in the window. -class ConvertAvgPoolOp : public OpRewritePattern { +template +class ConvertAvgPoolOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::AvgPoolOp op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto input_type = op.value().getType().dyn_cast(); + Value input_value = GetAvgPoolInput(op); + auto input_type = + input_value.getType().template dyn_cast(); if (!input_type) return failure(); - // TODO(b/147217034): support other data formats. - if (!IsDefaultDataFormat(op.data_format())) return failure(); - // TODO(b/147217034): support "SAME" padding. - if (op.padding() != "VALID") return failure(); - // We will do accumulation first; use a larger bitwidth if suitable. Type input_element_type = input_type.getElementType(); Type sum_element_type = GetSumAccumulationType(input_element_type); Type result_type; // The result type for reduction and division with the proper element type. - if (auto ranked_type = op.getType().dyn_cast()) + if (auto ranked_type = op.getType().template dyn_cast()) result_type = RankedTensorType::get(ranked_type.getShape(), sum_element_type); else result_type = UnrankedTensorType::get(sum_element_type); - Value input_value = op.value(); - // Convert if we need enlarge the element type's bitwidth. if (input_element_type != sum_element_type) input_value = rewriter.create(op.getLoc(), input_value, @@ -1787,9 +2153,9 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Create the tf.ReduceWindow op. 
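A 1-D worked version of what the AvgPoolDivideByCount helper above computes: with VALID padding every window holds exactly ksize real entries, while with SAME padding the edge windows overlap the padding and must be divided by a smaller count, which is what the reduce_window over an all-ones tensor produces. The padding arithmetic below uses the usual TF SAME-padding rule for one dimension.

#include <algorithm>
#include <cstdio>

int main() {
  const int input_len = 5, ksize = 3, stride = 1;
  // SAME padding in 1-D: output length is ceil(input_len / stride).
  const int output_len = (input_len + stride - 1) / stride;
  const int pad_total =
      std::max((output_len - 1) * stride + ksize - input_len, 0);
  const int pad_lo = pad_total / 2;
  for (int o = 0; o < output_len; ++o) {
    const int start = o * stride - pad_lo;
    // Count window positions that fall on real input, not padding.
    int count = 0;
    for (int w = 0; w < ksize; ++w) {
      const int i = start + w;
      if (i >= 0 && i < input_len) ++count;
    }
    std::printf("output %d: divide pooled sum by %d\n", o, count);
  }
  // Prints counts 2, 3, 3, 3, 2. With VALID padding every count is ksize (3).
  return 0;
}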
Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); - DenseIntElementsAttr paddings_attr = - GetReduceWindowPaddingAsAttr<4>(input_type.getShape(), op.ksize(), - op.strides(), op.padding(), &rewriter); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_type.getShape(), op.ksize(), op.strides(), op.padding(), + &rewriter); auto reduce = rewriter.create( op.getLoc(), result_type, input_value, init, GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), @@ -1799,19 +2165,17 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Count the number of elements in the window. The following calculation // is only valid for no paddings. - SmallVector ksize; + SmallVector input_shape( + llvm::to_vector(input_type.getShape())); + SmallVector ksize, strides; GetI64ArrayAttrValues(op.ksize(), &ksize); - int64_t count = std::accumulate(ksize.begin(), ksize.end(), 1, - std::multiplies()); + GetI64ArrayAttrValues(op.strides(), &strides); - // Divide by the number of elements in the window. - Value divisor = - GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - Value result = rewriter.create( - op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); + Operation *result_op = AvgPoolDivideByCount( + reduce.getResult(), input_shape, ksize, strides, op, init, rewriter); // Convert back if we enlarged the element type's bitwidth. + Value result = result_op->getOpResult(0); if (input_element_type != sum_element_type) result = rewriter.create(op.getLoc(), result, input_element_type); @@ -1821,6 +2185,9 @@ class ConvertAvgPoolOp : public OpRewritePattern { } }; +using ConvertAvgPool2DOp = ConvertAvgPoolOp; +using ConvertAvgPool3DOp = ConvertAvgPoolOp; + // `AvgPoolGradOp` is converted to the following operations: // 1. Divide each entry of the output gradient (the gradient for the previous // layer in backpropagation order) by the count of the corresponding window @@ -1894,59 +2261,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { auto orig_input_shape_values = orig_input_shape_attr.getValues(); DimVector orig_input_shape(orig_input_shape_values.begin(), orig_input_shape_values.end()); - RankedTensorType orig_input_type = - RankedTensorType::get(orig_input_shape, element_type); DimVector ksize, strides; GetI64ArrayAttrValues(op.ksize(), &ksize); GetI64ArrayAttrValues(op.strides(), &strides); Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter); - Operation *out_grad_divided = nullptr; - if (op.padding() == "VALID") { - // All window counts are equal here because we don't have padding - // (each entry of `out_grad` corresponds to a window that consists of - // original input entries only). - int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1, - std::multiplies()); - // Divide `out_grad` by window counts. - Value divisor = - GetScalarConstOfType(element_type, loc, window_count, &rewriter); - auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - out_grad_divided = rewriter.create( - loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims); - } else { - assert(op.padding() == "SAME"); - // For SAME padding, only original entries that contributed to a window - // are counted for the average of this window, not padded entries. - - // Build all-ones tensor of same shape as the original input. 
- ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); - auto all_ones_tensor = rewriter.create(loc, splat); - - // Get the same padding as for the original input. - DenseIntElementsAttr orig_padding_attr = - GetReduceWindowPaddingAsAttr(orig_input_shape, op.ksize(), - op.strides(), op.padding(), - &rewriter); - - // Count the 1's in each window, using the same padding as for the - // original input, which gives us the window counts by which `out_grad` - // needs to be divided. - auto window_counts = rewriter.create( - loc, out_grad_type, - /*operand=*/all_ones_tensor, - /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.ksize()), - /*window_strides=*/GetI64ElementsAttr(op.strides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), - /*padding=*/orig_padding_attr); - BuildReduceBody(element_type, &window_counts.body(), &rewriter); - - // Divide `out_grad` by window counts. - out_grad_divided = rewriter.create(loc, out_grad_type, - out_grad, window_counts); - } + auto out_grad_divided = AvgPoolDivideByCount( + out_grad, orig_input_shape, ksize, strides, op, zero, rewriter); // Get same padding as for original input. PaddingArray orig_padding = GetReduceWindowPaddingAsArray( @@ -2325,19 +2646,21 @@ class ConvertSizeOp : public OpRewritePattern { if (!input_ty) return failure(); const int64_t rank = input_ty.getRank(); - auto result_type = op.getResult().getType(); - Operation *size = - GetScalarConstOfType(result_type.cast().getElementType(), - op.getLoc(), 1, &rewriter); + auto result_ty = op.getResult().getType(); + auto element_ty = result_ty.cast().getElementType(); + Value size = GetScalarConstOfType(element_ty, op.getLoc(), 1, &rewriter); for (int64_t i = 0; i < rank; ++i) { - auto dim = rewriter.create( - op.getLoc(), result_type, input, - rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); + auto i32_ty = rewriter.getIntegerType(32); + auto size_ty = RankedTensorType::get({}, i32_ty); + auto dim_index = rewriter.getIntegerAttr(i32_ty, i); + Value dim = rewriter.create(op.getLoc(), size_ty, + input, dim_index); + dim = rewriter.create(op.getLoc(), result_ty, dim); size = rewriter.create( - op.getLoc(), size->getResult(0), dim.getResult(), + op.getLoc(), size, dim, /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } - rewriter.replaceOp(op, size->getResult(0)); + rewriter.replaceOp(op, size); return success(); } @@ -2380,7 +2703,8 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, rhs_type.getShape().drop_back(2), result_batch_shape_compile_time_extents); auto result_batch_shape = rewriter->create( - loc, lhs_splitted.head(), rhs_splitted.head(), + loc, shape::ShapeType::get(rewriter->getContext()), lhs_splitted.head(), + rhs_splitted.head(), /*error=*/nullptr); // Lambda which handles the broadcasting of one side to the common // leading-batch dimensions. @@ -2640,7 +2964,7 @@ class ConvertSplitVOp : public OpRewritePattern { SmallVector slices; slices.reserve(op.getNumResults()); - for (int i = 0; i < op.getNumResults(); ++i) { + for (int i = 0, end = op.getNumResults(); i < end; ++i) { end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; slices.push_back(rewriter.create( op.getLoc(), op.value(), GetI64ElementsAttr(begin_indices, &rewriter), @@ -2815,7 +3139,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // verifier. 
int64_t slicing_dim_size = op.begin().getType().cast().getShape()[0]; - auto input_rank = input_shape.size(); + const int input_rank = input_shape.size(); for (int d = slicing_dim_size; d < input_rank; ++d) { // We only support slicing major dimensions, so minor dimensions after // slicing dimensions are all sliced with their full sizes. @@ -2856,7 +3180,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { } // For non-slice dims, get the full slice of that dimension. - for (int d = slicing_dim_size; d < input_shape.size(); ++d) { + for (int d = slicing_dim_size, end = input_shape.size(); d < end; ++d) { slice_sizes.push_back(input_shape[d]); slice_begin_indices.push_back(zero); } @@ -3552,7 +3876,8 @@ class ConvertTileOp : public OpRewritePattern { multiples.getType().getRank() != 1) return failure(); - if (multiples.getNumElements() != input_shape.size()) return failure(); + const int64_t input_shape_size = input_shape.size(); + if (multiples.getNumElements() != input_shape_size) return failure(); SmallVector broadcasted_shape; SmallVector broadcast_dimensions; @@ -4339,7 +4664,7 @@ class ConvertUnpackOp : public OpRewritePattern { SmallVector results; results.reserve(op.getNumResults()); - for (int i = 0; i < op.getNumResults(); ++i) { + for (int i = 0, end = op.getNumResults(); i < end; ++i) { begin_indices[axis] = i; end_indices[axis] = i + 1; @@ -4698,7 +5023,12 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { SmallVector unpacked_indices_type( indices_type.getDimSize(0), RankedTensorType::get({}, indices_type.getElementType())); - auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(64), 0); + // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are + // required to have matching types. This rewrite rule creates + // DynamicUpdateSlice ops where the first "start index" is always i32 and + // subsequent ones are constructed based on zero_attr. Thus the type + // for zero_attr needs to be i32 as well. + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0); auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, indices, zero_attr); @@ -4777,11 +5107,8 @@ class ConvertCumsumOp : public OpRewritePattern { return failure(); } - // TODO(jennik): Add support for the optional 'exclusive' and 'reverse' - // arguments. - if (op.exclusive() || op.reverse()) { - return failure(); - } + ArrayRef input_shape = input_type.getShape(); + int64_t rank = input_shape.size(); // We can only match when the axis is a constant scalar. DenseIntElementsAttr axis_attr; @@ -4789,15 +5116,6 @@ class ConvertCumsumOp : public OpRewritePattern { return failure(); } - // Convert if we need to enlarge the element type's bitwidth to avoid - // precision loss. - Type input_element_type = input_type.getElementType(); - Type sum_element_type = GetSumAccumulationType(input_element_type); - input = rewriter.create(op.getLoc(), input, sum_element_type); - - ArrayRef input_shape = input_type.getShape(); - int64_t rank = input_shape.size(); - // Get the dimension to apply the reduction on, and offset properly if it is // negative. int64_t axis = (*axis_attr.begin()).getSExtValue(); @@ -4805,6 +5123,21 @@ class ConvertCumsumOp : public OpRewritePattern { axis += rank; } + // If we're supposed to sum things up in the reverse direction, we reverse + // the input and then later reverse the output. 
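A scalar sketch of the Cumsum decomposition implemented by the surrounding pattern: reverse the input when reverse=true, run an inclusive windowed sum, apply the pad trick (edge_padding_low = 1, edge_padding_high = -1) for exclusive=true, then reverse the result back. The 4-element input matches the lit tests earlier in this diff; the helper below is a plain reference implementation, not the MLIR code.

#include <algorithm>
#include <cstdio>
#include <vector>

// Reference semantics of tf.Cumsum along one axis, written the way the
// lowering decomposes it (reverse -> inclusive scan -> shift -> reverse back).
std::vector<float> Cumsum(std::vector<float> x, bool exclusive, bool reverse) {
  if (reverse) std::reverse(x.begin(), x.end());
  // Inclusive running sum; the lowering expresses this as a reduce_window
  // whose window covers everything up to and including each position.
  std::vector<float> y(x.size());
  float acc = 0;
  for (size_t i = 0; i < x.size(); ++i) y[i] = acc += x[i];
  if (exclusive) {
    // Equivalent to the mhlo.pad with a leading 0 and the last entry dropped.
    y.insert(y.begin(), 0.0f);
    y.pop_back();
  }
  if (reverse) std::reverse(y.begin(), y.end());
  return y;
}

int main() {
  for (float v : Cumsum({1, 2, 3, 4}, /*exclusive=*/true, /*reverse=*/true))
    std::printf("%g ", v);  // prints: 9 7 4 0
  std::printf("\n");
  return 0;
}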
+ if (op.reverse()) { + llvm::SmallVector dims_to_reverse({axis}); + input = rewriter.create( + op.getLoc(), op.getType(), input, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + // Convert if we need to enlarge the element type's bitwidth to avoid + // precision loss. + Type input_element_type = input_type.getElementType(); + Type sum_element_type = GetSumAccumulationType(input_element_type); + input = rewriter.create(op.getLoc(), input, sum_element_type); + SmallVector window_dims(rank, 1); SmallVector window_strides(rank, 1); window_dims[axis] = input_shape[axis]; @@ -4827,10 +5160,34 @@ class ConvertCumsumOp : public OpRewritePattern { BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); Value result = reduce.getResult(); + if (op.exclusive()) { + // In "exclusive" operation, the output will start with the "init" (0) + // values. There is no way to express that as a ReduceWindowOp, so run the + // normal operation, and then use a PadOp to add the 0 "column" on the + // left and cut away the last column on the right. + llvm::SmallVector low_padding(rank, 0); + llvm::SmallVector high_padding(rank, 0); + llvm::SmallVector interior_padding(rank, 0); + low_padding[axis] = 1; + high_padding[axis] = -1; + result = rewriter.create( + op.getLoc(), op.getType(), result, init, + GetI64ElementsAttr(low_padding, &rewriter), + GetI64ElementsAttr(high_padding, &rewriter), + GetI64ElementsAttr(interior_padding, &rewriter)); + } + // Convert back if we enlarged the element type's bitwidth. result = rewriter.create(op.getLoc(), result, input_element_type); + if (op.reverse()) { + llvm::SmallVector dims_to_reverse({axis}); + result = rewriter.create( + op.getLoc(), op.getType(), result, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + rewriter.replaceOp(op, result); return success(); } @@ -5358,50 +5715,6 @@ class ConvertQrOp : public OpRewritePattern { } }; -// Converts `TF::SparseMatMulOp` to `TF::MatMulOp`, ignoring the sparseness -// hints, since we currently don't have an implementation that can use this -// information. Adds appropriate casts where necessary to align element types -// of operands and result for `TF::MatMulOp`. -class ConvertSparseMatMulOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TF::SparseMatMulOp op, - PatternRewriter &rewriter) const override { - // Result type must be f32 for applying the pattern (currently this is - // required by the op anyway but this might change). - if (!op.product().getType().cast().getElementType().isF32()) { - return failure(); - } - MLIRContext *context = rewriter.getContext(); - llvm::SmallVector operands{op.a(), op.b()}; - for (Value &operand : operands) { - TensorType tensor_type = operand.getType().cast(); - Type element_type = tensor_type.getElementType(); - if (element_type.isF32()) continue; - // Element type can either be f32 or bf16 for `SparseMatMulOp` so it - // must be bf16 here. - assert(element_type.isBF16()); - Type tensor_type_f32; - if (tensor_type.hasRank()) { - tensor_type_f32 = RankedTensorType::get(tensor_type.getShape(), - FloatType::getF32(context)); - } else { - tensor_type_f32 = UnrankedTensorType::get(FloatType::getF32(context)); - } - // Add cast to f32 to conform with element type of result. 
- operand = - rewriter.create(op.getLoc(), tensor_type_f32, operand); - } - Value result = rewriter.create( - op.getLoc(), op.product().getType(), operands[0], operands[1], - op.transpose_a(), op.transpose_b()); - - rewriter.replaceOp(op, {result}); - return success(); - } -}; - // Emits debug information which includes the number of ops of each type which // failed to legalize. void EmitLegalizationErrors(Operation *op, @@ -5449,9 +5762,14 @@ void EmitLegalizationErrors(Operation *op, // Performs the lowering to XLA dialect. void LegalizeTF::runOnFunction() { - if (failed( - legalizeTF(getFunction(), allow_partial_conversion_, legalize_chlo_))) + llvm::Optional tf2xla_fallback_device_type = llvm::None; + if (use_tf2xla_fallback_) { + tf2xla_fallback_device_type = device_type_; + } + if (failed(legalizeTF(getFunction(), allow_partial_conversion_, + legalize_chlo_, tf2xla_fallback_device_type))) { signalPassFailure(); + } } static PassRegistration pass( @@ -5461,53 +5779,48 @@ static PassRegistration pass( #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" -LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, - bool legalize_chlo) { +LogicalResult legalizeTF( + Operation *op, bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { MLIRContext *context = op->getContext(); - - // Add lowering patterns to the list. OwningRewritePatternList patterns; - populateWithGenerated(context, &patterns); + // Note that the `OperationConverter` orders patterns lexicographically by: + // 1) Ascending legalization depth (i.e., minimum number of patterns necessary + // to arrive at conversion target). + // 2) Descending pattern benefit. + // 3) Order of patterns in `OwningRewritePatternList`. - // Add patterns that lower some of the high level TensorFlow ops to lower - // level TensorFlow ops. So, we don't have to target all the TensorFlow ops - // here for lowering to HLO. + // Add TF->HLO legalization patterns. + PopulateLegalizeTfPatterns(context, &patterns); + + // Add TF->TF lowering patterns. 
TF::PopulateLoweringTFPatterns(context, &patterns); - patterns.insert< - ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, - ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, - ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, - ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, - ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, - ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, - ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, - ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV3Op, - ConvertInfeedDequeueTupleOp, ConvertInplaceUpdateOp, ConvertLinSpaceOp, - ConvertMaxOp, ConvertMinOp, ConvertAvgPoolOp, ConvertAvgPool2DGradOp, - ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, - ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, - ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, - ConvertDynamicRangeOp, ConvertRangeOp, ConvertSelectV2Op, - ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, - ConvertSoftmaxOp, - ConvertSoftmaxOp, ConvertSparseMatMulOp, - ConvertSplitOp, ConvertSplitVOp, ConvertStridedSliceOp, - ConvertStridedSliceGradOp, ConvertSumOp, ConvertTensorScatterUpdateOp, - ConvertTileOp, ConvertTopKV2Op, ConvertUnpackOp, - ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, - ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp, - ConvertRandomShuffleOp, ConvertXlaShardingOp, - ConvertXlaDynamicUpdateSliceOp>(op->getContext()); + + // Add TF->HLO legalization patterns via TF2XLA fallback. + if (tf2xla_fallback_device_type.hasValue()) { + PopulateLegalizeTfWithTf2XlaPatterns(tf2xla_fallback_device_type.getValue(), + patterns); + } // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. if (legalize_chlo) { chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); } + // ConstantLike op is convenient to create splat constants, but is + // canonicalized to plain HLO constant if statically shaped. Add the + // canonicalization pattern to pattern list to enable multi-hop lowering. + chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); ConversionTarget target(*context); if (legalize_chlo) { target.addIllegalDialect(); + + // Mark ConstantLikeOp as dynamically legal only when it doesn't have a + // static result type so that it gets canonicalized to MHLO constant. 
+ target.addDynamicallyLegalOp([](Operation *op) { + return !op->getResultTypes().front().cast().hasStaticShape(); + }); } else { target.addLegalDialect(); } @@ -5535,9 +5848,41 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, return applyPartialConversion(op, target, patterns); } +void PopulateLegalizeTfPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + populateWithGenerated(context, patterns); + patterns->insert< + ConvertAllOp, ConvertAnyOp, ConvertArgMaxOp, ConvertBatchMatMulV2Op, + ConvertBiasAddOp, ConvertBroadcastToOp, ConvertBF16FloorDivOp, + ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, + ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, + ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, + ConvertInplaceUpdateOp, ConvertLinSpaceOp, ConvertMaxOp, ConvertMinOp, + ConvertAvgPool2DOp, ConvertAvgPool3DOp, ConvertAvgPool2DGradOp, + ConvertAvgPool3DGradOp, ConvertMaxPool2DOp, ConvertMaxPool3DOp, + ConvertMaxPool2DGradOp, ConvertMaxPool3DGradOp, ConvertMeanOp, + ConvertOneHotOp, ConvertOutfeedEnqueueTupleOp, ConvertProdOp, ConvertQrOp, + ConvertDynamicRangeOp, ConvertMatrixDiagPartV3Op, ConvertRangeOp, + ConvertSelectV2Op, ConvertSigmoidOp, ConvertShapeOp, ConvertSizeOp, + ConvertSoftmaxOp, + ConvertSoftmaxOp, ConvertSplitOp, ConvertSplitVOp, + ConvertStridedSliceOp, ConvertStridedSliceGradOp, ConvertSumOp, + ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, + ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, + ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp, + ConvertRandomShuffleOp, ConvertXlaShardingOp, + ConvertXlaDynamicUpdateSliceOp>(context); +} + std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion, bool legalize_chlo) { - return std::make_unique(allow_partial_conversion, legalize_chlo); + bool allow_partial_conversion, bool legalize_chlo, + llvm::Optional tf2xla_fallback_device_type) { + return std::make_unique(allow_partial_conversion, legalize_chlo, + tf2xla_fallback_device_type); } } // end namespace mhlo diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc new file mode 100644 index 00000000000..1d6ce36300f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -0,0 +1,907 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering TensorFlow dialect's communication +// ops (TF/XLA) to the HLO dialect. 
+ +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" + +namespace mlir { +namespace mhlo { + +namespace { +constexpr char kShardingAttr[] = "mhlo.sharding"; +constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; +const char kXlaHostTransferRendezvousNameAttr[] = + "_xla_host_transfer_rendezvous"; +const char kXlaHostTransferOriginalTypeAttr[] = + "_xla_host_transfer_original_type"; + +// A pass that legalizes TF/XLA communication ops, propagate their respective +// tokens (for ordering), and rewrite their respective functions and control +// flow ops when necessary. +// Note, this currently does not handle nested modules/functions or region based +// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`). +class LegalizeTFCommunication + : public PassWrapper> { + public: + void runOnOperation() override; +}; + +// Checks if an op is a TF/XLA communication op. +bool IsCommunicationOp(Operation* op) { + return isa(op); +} + +// Checks if an op is a supported HLO control flow op. +bool IsControlFlowOp(Operation* op) { return isa(op); } + +// Collects control flow op ancestors of a given op, up until FuncOp. If any +// ancestor is not a control flow op or a FuncOp, or of a single block region, +// an error will be returned. +LogicalResult GetControlFlowAncestors( + Operation* op, llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks) { + Block* block = op->getBlock(); + Operation* parent = block->getParentOp(); + while (block && parent && !isa(parent)) { + if (!IsControlFlowOp(parent)) + return op->emitOpError() + << "expects ancestor(s) to be of ['" << IfOp::getOperationName() + << "', '" << FuncOp::getOperationName() << "']"; + + if (!llvm::hasSingleElement(block->getParent()->getBlocks())) + return op->emitOpError() << "expects single block region ancestor(s)"; + + control_flow_ops.insert(parent); + control_flow_blocks.insert(block); + + parent = block->getParentOp(); + block = parent->getBlock(); + } + return success(); +} + +// Finds communication ops in a function. `control_flow_ops` and +// `control_flow_blocks` will be populated with control flow op ancestors for +// every communication op. 
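A toy sketch of the ancestor walk in GetControlFlowAncestors above, using a hypothetical node type with parent pointers instead of MLIR's Operation/Block; the real code additionally rejects multi-block regions and reports errors through op->emitOpError rather than exceptions.

#include <set>
#include <stdexcept>
#include <string>

// Hypothetical IR node: kind is "func", "if", "while", or something else.
struct Node {
  std::string kind;
  Node* parent = nullptr;
};

// Collect control-flow ancestors of `op` up to (not including) the enclosing
// function; anything else on the path is an error, mirroring the pass.
std::set<Node*> ControlFlowAncestors(Node* op) {
  std::set<Node*> ancestors;
  for (Node* p = op->parent; p && p->kind != "func"; p = p->parent) {
    if (p->kind != "if" && p->kind != "while")
      throw std::runtime_error("expects ancestors to be control flow or func");
    ancestors.insert(p);
  }
  return ancestors;
}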
+LogicalResult FindCommunicationOps( + FuncOp func, llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks, + bool& has_communication_ops) { + auto result = func.walk([&](Operation* op) { + if (!IsCommunicationOp(op)) return WalkResult::advance(); + has_communication_ops = true; + if (failed( + GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +// Helper struct holding a function to be rewritten, it's control flow ops that +// lead to a communication op or function call with a communication op +// (transitively), and an optional clone of itself. If `clone` is set, function +// calls to `original` will be replaced with `clone`. +struct FuncToRewrite { + FuncOp original; + llvm::SmallPtrSet control_flow_ops; + llvm::SmallPtrSet control_flow_blocks; + FuncOp clone; +}; + +// Finds all functions that need to be rewritten with communication ops and +// and associated tokens. +LogicalResult GetFunctionsToRewrite( + ModuleOp module, + llvm::SmallDenseMap& funcs_to_rewrite) { + // Find functions containing communication ops. + SmallVector funcs_to_visit; + for (FuncOp func : module.getOps()) { + FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{}, + /*control_flow_blocks=*/{}, + /*clone=*/nullptr}; + bool has_communication_ops = false; + if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks, + has_communication_ops))) + return failure(); + + if (!has_communication_ops) continue; + funcs_to_rewrite.insert({func.getName(), func_to_rewrite}); + funcs_to_visit.push_back(func); + } + + // Find functions that call functions with communication ops, transitively. + while (!funcs_to_visit.empty()) { + SmallVector new_funcs_to_visit; + for (FuncOp& func : funcs_to_visit) { + auto uses = func.getSymbolUses(module); + if (!uses) continue; + for (auto& use : *uses) { + // Only `mlir::CallOp` is supported as this requires knowing how to + // rewrite arguments and results to a function. + if (!isa(use.getUser())) continue; + auto caller_parent_func = use.getUser()->getParentOfType(); + if (!caller_parent_func) continue; + + FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func, + /*control_flow_ops=*/{}, + /*control_flow_blocks=*/{}, + /*clone=*/nullptr}; + if (failed(GetControlFlowAncestors( + use.getUser(), func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks))) + return failure(); + + auto it = funcs_to_rewrite.insert( + {caller_parent_func.getName(), func_to_rewrite}); + if (it.second) { + new_funcs_to_visit.push_back(caller_parent_func); + } else { + it.first->getSecond().control_flow_ops.insert( + func_to_rewrite.control_flow_ops.begin(), + func_to_rewrite.control_flow_ops.end()); + it.first->getSecond().control_flow_blocks.insert( + func_to_rewrite.control_flow_blocks.begin(), + func_to_rewrite.control_flow_blocks.end()); + } + } + } + + funcs_to_visit.swap(new_funcs_to_visit); + } + + // Clone public functions that need to be rewritten. Function calls to this + // function will be replaced with the cloned function. 
+ SymbolTable symbol_table(module); + for (auto& func : funcs_to_rewrite) { + if (func.getSecond().original.isPublic() && + !func.getSecond().original.symbolKnownUseEmpty(module)) { + auto clone = func.getSecond().original.clone(); + clone.setVisibility(SymbolTable::Visibility::Private); + symbol_table.insert(clone); + func.getSecond().clone = clone; + } + } + + return success(); +} + +// Assigns op sharding to an op for a given device core. +void SetOpSharding(Operation* op, int64_t tpu_core) { + std::string sharding_serialized = + ::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString(); + op->setAttr(kShardingAttr, + StringAttr::get(sharding_serialized, op->getContext())); +} + +// Assigns frontend attributes holding information about data type and +// TensorFlow rendezvous channel name. +void SetFrontendAttributes(Operation* op, StringRef key, Type type) { + MLIRContext* context = op->getContext(); + + auto rendezvous_name = StringAttr::get(key, context); + auto rendezvous_name_attr = NamedAttribute( + Identifier::get(kXlaHostTransferRendezvousNameAttr, context), + rendezvous_name); + + auto element_type = getElementTypeOrSelf(type); + auto xla_element_type = ::xla::TypeToPrimitiveType(element_type); + const std::string& xla_element_type_str = + ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type); + auto original_type = StringAttr::get(xla_element_type_str, context); + auto original_type_attr = + NamedAttribute(Identifier::get(kXlaHostTransferOriginalTypeAttr, context), + original_type); + + auto frontend_attributes = DictionaryAttr::get( + ArrayRef{rendezvous_name_attr, original_type_attr}, + context); + op->setAttr(kFrontendAttributesAttr, frontend_attributes); +} + +// Assigns frontend attributes holding information about data type and +// TensorFlow rendezvous channel name specific to `tf._XlaHostComputeMlir`. +// TensorFlow rendezvous channel name is handled differently as individual names +// are used per data send and receive. +void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, + Type type, bool device_to_host) { + std::string formatted_key = + device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str() + : llvm::formatv("{0}_htod_{1}", key, index).str(); + + return SetFrontendAttributes(op, formatted_key, type); +} + +// Creates a `mhlo.send` op for sending value `operand`. If `index` is set, +// `key` will be rewritten with a suffix and index. If `tpu_core` is set, op +// sharding for the respective device will be set. +Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, + Value operand, StringRef key, const Optional& index, + const Optional& tpu_core, Value token) { + // type 2 == DEVICE_TO_HOST + auto channel_handle = ChannelHandle::get( + /*handle=*/builder.getI64IntegerAttr(channel_id++), + /*type=*/builder.getI64IntegerAttr(2), builder.getContext()); + auto send = builder.create( + loc, token.getType(), operand, token, channel_handle, + /*is_host_transfer=*/builder.getBoolAttr(true)); + + if (index) { + SetFrontendAttributes(send, *index, key, operand.getType(), + /*device_to_host=*/true); + } else { + SetFrontendAttributes(send, key, operand.getType()); + } + + if (tpu_core) SetOpSharding(send, *tpu_core); + + return send.getResult(); +} + +// Creates a `mhlo.recv` op for receiving a value. If `index` is set, `key` will +// be rewritten with a suffix and index. If `tpu_core` is set, op sharding for +// the respective device will be set. 
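The SetFrontendAttributes overload above derives a distinct rendezvous key per transfer by appending a direction suffix and the operand/result index to the `tf._XlaHostComputeMlir` key. A standalone sketch of the resulting key strings (the key value is a made-up example):

#include <iostream>
#include <string>
#include "llvm/Support/FormatVariadic.h"

int main() {
  std::string key = "host_compute_channel_0";  // hypothetical send/recv key
  // Device-to-host transfer of operand #1 and host-to-device transfer of
  // result #0, matching the "{0}_dtoh_{1}" / "{0}_htod_{1}" format strings.
  std::cout << llvm::formatv("{0}_dtoh_{1}", key, 1).str() << "\n";
  std::cout << llvm::formatv("{0}_htod_{1}", key, 0).str() << "\n";
  return 0;
}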
+Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, + Value result, StringRef key, const Optional& index, + const Optional& tpu_core, Value token) { + // type 3 == HOST_TO_DEVICE + auto channel_handle = ChannelHandle::get( + /*handle=*/builder.getI64IntegerAttr(channel_id++), + /*type=*/builder.getI64IntegerAttr(3), builder.getContext()); + auto result_type = result.getType(); + auto recv_result_type = + TupleType::get({result_type, token.getType()}, builder.getContext()); + auto recv = + builder.create(loc, recv_result_type, token, channel_handle, + /*is_host_transfer=*/builder.getBoolAttr(true)); + if (index) { + SetFrontendAttributes(recv, *index, key, result_type, + /*device_to_host=*/false); + } else { + SetFrontendAttributes(recv, key, result.getType()); + } + if (tpu_core) SetOpSharding(recv, *tpu_core); + + auto get_tuple_element = + builder.create(loc, recv.getResult(), /*index=*/0); + if (tpu_core) SetOpSharding(get_tuple_element, *tpu_core); + + result.replaceAllUsesWith(get_tuple_element); + + auto new_token = builder.create(loc, recv.getResult(), + /*index=*/1); + if (tpu_core) SetOpSharding(new_token, *tpu_core); + + return new_token.getResult(); +} + +// Creates a new token if necessary, acting as a sink to previous tokens. If +// there is only one token in `tokens`, the only token is returned. If `tokens` +// is empty, `original_token` is returned instead. +Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef tokens, + Value original_token) { + if (tokens.empty()) { + return original_token; + } else if (llvm::hasSingleElement(tokens)) { + return tokens[0]; + } else { + return builder.create(loc, original_token.getType(), tokens) + .getResult(); + } +} + +// Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv` +// ops per operand and result. Unique Channel Id's are assigned per transfer. +// Sink tokens are created across all `mhlo.send` ops first and then by +// all `mhlo.recv` ops. +Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id, + TF::_XlaHostComputeMlirOp host_compute, + Value token) { + builder.setInsertionPoint(host_compute); + Location loc = host_compute.getLoc(); + int64_t tpu_core = host_compute.tpu_coreAttr().getInt(); + + SmallVector send_tokens; + for (auto operand : llvm::enumerate(host_compute.inputs())) { + auto send_token = + CreateSendOp(builder, channel_id, loc, operand.value(), + host_compute.send_key(), operand.index(), tpu_core, token); + send_tokens.push_back(send_token); + } + token = CreateSinkToken(builder, loc, send_tokens, token); + + SmallVector recv_tokens; + for (auto result : llvm::enumerate(host_compute.outputs())) { + auto recv_token = + CreateRecvOp(builder, channel_id, loc, result.value(), + host_compute.recv_key(), result.index(), tpu_core, token); + recv_tokens.push_back(recv_token); + } + token = CreateSinkToken(builder, loc, recv_tokens, token); + + host_compute.erase(); + return token; +} + +// Replaces `tf.XlaSendToHost` with a `mhlo.send`. +Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, + TF::XlaSendToHostOp send_to_host, Value token) { + builder.setInsertionPoint(send_to_host); + token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), + send_to_host.input(), send_to_host.key(), + /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + + send_to_host.erase(); + return token; +} + +// Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`. 
+Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, + TF::XlaRecvFromHostOp recv_from_host, Value token) { + builder.setInsertionPoint(recv_from_host); + token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), + recv_from_host.output(), recv_from_host.key(), + /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + + recv_from_host.erase(); + return token; +} + +// Replaces a `mlir::CallOp` with one that has an extra `!mhlo.token` operand +// and `!mhlo.token` result. If `new_symbol` is set, the new call will be +// updated to call the `new_symbol` instead. +Value RewriteCallOp(OpBuilder& builder, CallOp call, + const Optional& new_symbol, Value token) { + builder.setInsertionPoint(call); + auto new_operands = llvm::to_vector<4>(call.getArgOperands()); + new_operands.push_back(token); + auto new_result_types = llvm::to_vector<4>(call.getResultTypes()); + new_result_types.push_back(token.getType()); + auto new_call = builder.create( + call.getLoc(), new_result_types, new_symbol ? *new_symbol : call.callee(), + new_operands); + + for (auto results : llvm::zip(call.getResults(), new_call.getResults())) + std::get<0>(results).replaceAllUsesWith(std::get<1>(results)); + call.erase(); + return new_call.getResults().back(); +} + +// Helper struct holding state of which op to visit to next. If `op` is in a +// control flow op region, `region_idx` will be set with the respective region +// index. `token` will be current token from the last communication op/control +// flow op transitive communication ops. +struct OpVisitorState { + Optional region_idx; + Value token; + Operation* op; +}; + +// Creates a tuple from a sequence of values. +Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef operands) { + return builder.create(loc, operands).getResult(); +} + +// Replaces a value `value` with a new value but the token attached. If `value` +// is not a tuple, a new tuple is formed with `token`. If `value` is a tuple, +// `value` is extended instead. New tuple values created are cached. +Value GetValueWithToken(OpBuilder& builder, Value value, Value token, + llvm::SmallDenseMap& rewritten_values) { + // If value with token already exists, reuse it. + auto it = rewritten_values.find(value); + if (it != rewritten_values.end()) return it->getSecond(); + + auto create_tuple = [&](ArrayRef operands) { + auto new_result = CreateTuple(builder, value.getLoc(), operands); + rewritten_values.insert({value, new_result}); + return new_result; + }; + + auto tuple_type = value.getType().dyn_cast(); + // `value` is not a tuple, create a new tuple. + if (!tuple_type) return create_tuple({value, token}); + + // Extend tuple if `value` is a tuple. + // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack + // the tuple. + if (auto tuple_op = value.getDefiningOp()) { + auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands()); + tuple_operands.push_back(token); + return create_tuple(tuple_operands); + } + + // `value` is not created via a `mhlo.tuple` directly, unpack individual + // elements directly with `mhlo.get_tuple_element`. + SmallVector tuple_operands; + for (auto idx : llvm::seq(0, tuple_type.getTypes().size())) + tuple_operands.push_back( + builder.create(value.getLoc(), value, idx) + .getResult()); + + tuple_operands.push_back(token); + return create_tuple(tuple_operands); +} + +// Extends a type to include a `mhlo.token` type. If `type` is not a tuple type, +// a new tuple type with `type` and `mhlo.token` type is created instead. 
+TupleType GetTypeWithToken(OpBuilder& builder, Type type) { + auto token_type = TokenType::get(builder.getContext()); + if (auto tuple_type = type.dyn_cast()) { + auto result_types = llvm::to_vector<4>(tuple_type.getTypes()); + result_types.push_back(token_type); + return builder.getTupleType(result_types); + } + + return builder.getTupleType({type, token_type}); +} + +// Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0 +// to `end`, exclusive. +Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) { + SmallVector tuple_operands; + for (auto idx : llvm::seq(0, end)) + tuple_operands.push_back( + builder.create(value.getLoc(), value, idx) + .getResult()); + + return CreateTuple(builder, value.getLoc(), tuple_operands); +} + +// Replaces uses of `value` with `replacement`. If `value` is not a tuple type, +// an explicit `mhlo.get_tuple_element` is created to unpack the tuple and +// return the first element. Otherwise, `mhlo.get_tuple_element` users are +// simply updated with `replacement`, and all other users are updated with a +// slice of `replacement`. +void ReplaceWithTupleResult(OpBuilder& builder, Value value, + Value replacement) { + auto tuple_type = value.getType().dyn_cast(); + if (!tuple_type) { + if (!value.use_empty()) { + auto new_element = builder.create(replacement.getLoc(), + replacement, 0); + value.replaceAllUsesWith(new_element.getResult()); + } + return; + } + + Value sub_tuple; + for (auto& use : llvm::make_early_inc_range(value.getUses())) { + if (isa(use.getOwner())) { + use.set(replacement); + continue; + } + + if (!sub_tuple) + sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size()); + + use.set(sub_tuple); + } +} + +// Replaces control flow op block single block argument with new block argument +// of type `new_type` (tuple type). The last element of the new block argument +// (token) is returned. +Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block, + Type token_type) { + assert(block.getNumArguments() == 1); + builder.setInsertionPointToStart(&block); + auto new_arg = block.addArgument(token_type); + ReplaceWithTupleResult(builder, block.getArgument(0), new_arg); + block.eraseArgument(0); + return builder + .create(new_arg.getLoc(), new_arg, + token_type.cast().size() - 1) + .getResult(); +} + +// Updates control flow op terminator with an extra element `token`. If the +// original return value is not a tuple, a new tuple is formed. Otherwise the +// tuple is extended. +void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator, + Value token) { + assert(terminator->getNumOperands() == 1); + assert(terminator->getBlock()->getNumArguments() == 1); + // `mhlo.while` cond terminator does not need to be rewritten as it always + // returns a tensor predicate value. + if (auto while_parent = dyn_cast_or_null(terminator->getParentOp())) + if (terminator->getParentRegion() == &while_parent.cond()) return; + + builder.setInsertionPoint(terminator); + llvm::SmallDenseMap rewritten_operands; + Value new_result = GetValueWithToken(builder, terminator->getOperand(0), + token, rewritten_operands); + terminator->setOperand(0, new_result); +} + +// Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to +// the op for all of its regions are extended to have an extra operand `token`. 
+void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if, + SmallVectorImpl& ops_to_visit, + Value token) { + llvm::SmallDenseMap rewritten_operands; + + // Rewrite all region operands to have an extra operand `token`. + Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(), + token, rewritten_operands); + Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(), + token, rewritten_operands); + + auto new_result_type = GetTypeWithToken(builder, region_if.getType()); + + // Create new `mhlo.if` op with extra token operands and result. + auto new_if = builder.create(region_if.getLoc(), new_result_type, + region_if.pred(), new_true_operand, + new_false_operand); + + // Move all regions from the old `mhlo.if` op to its replacement. + new_if.true_branch().takeBody(region_if.true_branch()); + new_if.false_branch().takeBody(region_if.false_branch()); + + // Forward result from old `mhlo.if` with replacement, and unpack result when + // necessary. + ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult()); + + auto new_token = builder.create( + new_if.getLoc(), new_if.getResult(), + new_if.getResult().getType().cast().size() - 1); + + region_if.erase(); + + // Remove leftover operands to old `mhlo.if` if they have no uses. + for (auto& rewritten_operand : rewritten_operands) + if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp()) + if (tuple_op.use_empty()) tuple_op.erase(); + + // Next op to visit. The replacement is visited but at its first region. The + // token result of the new region if is propagated. + ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if}); +} + +// Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a +// `mhlo.token`. The block argument is updated to have an extra `mhlo.token` +// element. If the region block is to be rewritten, the next op to visit is set +// to the first op in the block. Otherwise the terminator is updated to forward +// `token`. +void RewriteControlFlowOpRegion( + OpBuilder& builder, Operation* region_op, unsigned region_idx, + Type block_arg_type, SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, Value token) { + ops_to_visit.push_back({region_idx + 1, token, region_op}); + + Region& region = region_op->getRegion(region_idx); + assert(llvm::hasSingleElement(region)); + + auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(), + block_arg_type); + + if (control_flow_blocks.contains(®ion.front())) { + ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token, + block_token.getDefiningOp()->getNextNode()}); + return; + } + + RewriteControlFlowTerminator(builder, region.front().getTerminator(), + block_token); +} + +// Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op +// operands and results are rewritten. If `region_idx` is set, region +// `region_idx` is rewritten to take in and return an additional token. Returns +// true if the op or its region was rewritten. 
+bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if, + Optional region_idx, + SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, + Value token) { + builder.setInsertionPoint(region_if); + + if (!region_idx) { + RewriteRegionIfOp(builder, region_if, ops_to_visit, token); + return true; + } + + if (*region_idx < region_if.getNumRegions()) { + RewriteControlFlowOpRegion(builder, region_if, *region_idx, + region_if.getOperand(*region_idx + 1).getType(), + ops_to_visit, control_flow_blocks, token); + return true; + } + + return false; +} + +// Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to +// the op for all of its regions are extended to have an extra operand `token`. +void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while, + SmallVectorImpl& ops_to_visit, + Value token) { + llvm::SmallDenseMap rewritten_operands; + + // Rewrite region operand to have an extra operand `token`. + Value new_val_operand = + GetValueWithToken(builder, region_while.val(), token, rewritten_operands); + + auto new_result_type = GetTypeWithToken(builder, region_while.getType()); + + // Create new `mhlo.while` op with extra token operand and result. + auto new_while = builder.create(region_while.getLoc(), + new_result_type, new_val_operand); + + // Move all regions from the old `mhlo.while` op to its replacement. + new_while.cond().takeBody(region_while.cond()); + new_while.body().takeBody(region_while.body()); + + // Forward result from old `mhlo.while` with replacement, and unpack result + // when necessary. + ReplaceWithTupleResult(builder, region_while.getResult(), + new_while.getResult()); + + auto new_token = builder.create( + new_while.getLoc(), new_while.getResult(), + new_while.getResult().getType().cast().size() - 1); + + region_while.erase(); + + // Remove leftover operands to old `mhlo.while` if they have no uses. + for (auto& rewritten_operand : rewritten_operands) + if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp()) + if (tuple_op.use_empty()) tuple_op.erase(); + + // Next op to visit. The replacement is visited but at its first region. The + // token result of the new region if is propagated. + ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while}); +} + +// Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op +// operands and results are rewritten. If `region_idx` is set, region +// `region_idx` is rewritten to take in and return an additional token. Returns +// true if the op or its region was rewritten. +bool ProcessRegionWhileOp( + OpBuilder& builder, WhileOp region_while, Optional region_idx, + SmallVectorImpl& ops_to_visit, + const llvm::SmallPtrSetImpl& control_flow_blocks, Value token) { + builder.setInsertionPoint(region_while); + + if (!region_idx) { + RewriteRegionWhileOp(builder, region_while, ops_to_visit, token); + return true; + } + + if (*region_idx < region_while.getNumRegions()) { + RewriteControlFlowOpRegion(builder, region_while, *region_idx, + region_while.val().getType(), ops_to_visit, + control_flow_blocks, token); + return true; + } + + return false; +} + +// Updates function type based on current function body block arguments and +// terminator operand types. 
+void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) { + auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes()); + auto new_result_types = + llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes()); + func.setType(FunctionType::get(new_argument_types, new_result_types, + builder.getContext())); +} + +// Replaces a function terminator `return` with another `return` that has an +// extra `mhlo.token` operand. +void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator, + Value token) { + auto new_results = llvm::to_vector<4>(terminator.getOperands()); + new_results.push_back(token); + builder.setInsertionPoint(terminator); + builder.create(terminator.getLoc(), new_results); + terminator.erase(); +} + +// Rewrites a function body and communication ops inside. Region control flow +// are updated when necessary, to propagate tokens. The function may either be +// rewritten to create a token or take in and return a token, depending on its +// visibility and if there are any callers. +LogicalResult RewriteFunction( + OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func, + const llvm::SmallDenseMap& funcs, + const llvm::SmallPtrSetImpl& control_flow_ops, + const llvm::SmallPtrSetImpl& control_flow_blocks, bool is_clone) { + MLIRContext* context = module.getContext(); + if (!llvm::hasSingleElement(func.getBody())) + return func.emitError() + << "'" << FuncOp::getOperationName() + << "' ops with more than one block are not supported"; + + bool rewrite_block = + is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module)); + Block& func_body = func.front(); + + builder.setInsertionPointToStart(&func_body); + auto token_type = TokenType::get(context); + // If a function is public, it's signature should not be modified, and instead + // a token will be created. Otherwise a token block argument is inserted. + Value init_token = + rewrite_block ? func_body.addArgument(token_type) + : builder.create(func.getLoc(), token_type) + .getResult(); + + // Stack to keep track of region based control flow op nesting and current + // op to visit. + SmallVector ops_to_visit{ + {/*region_idx=*/llvm::None, init_token, &func_body.front()}}; + + while (!ops_to_visit.empty()) { + OpVisitorState op_to_visit = ops_to_visit.pop_back_val(); + Operation* curr_op = op_to_visit.op; + + Value token = op_to_visit.token; + // Ops may be removed, so the next op is kept track of beforehand. + Operation* next_op = curr_op->getNextNode(); + + if (auto host_compute = dyn_cast(curr_op)) { + token = RewriteHostComputeOp(builder, channel_id, host_compute, token); + } else if (auto send_to_host = dyn_cast(curr_op)) { + token = RewriteSendToHostOp(builder, channel_id, send_to_host, token); + } else if (auto recv_from_host = dyn_cast(curr_op)) { + token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token); + } else if (auto call = dyn_cast(curr_op)) { + // Only `mlir::CallOp` is supported as this requires knowing how to + // rewrite arguments and results to a function. + auto it = funcs.find(call.getCallee()); + if (it != funcs.end()) { + FuncOp clone = it->getSecond().clone; + Optional symbol_name = + clone ? Optional(clone.getName()) : llvm::None; + // If the function being called is to be cloned, update the call to also + // point to the cloned function. 
+ token = RewriteCallOp(builder, call, symbol_name, token); + } + } else if (auto region_if = dyn_cast(curr_op)) { + if (op_to_visit.region_idx || control_flow_ops.contains(region_if)) + if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx, + ops_to_visit, control_flow_blocks, token)) + continue; + } else if (auto region_while = dyn_cast(curr_op)) { + if (op_to_visit.region_idx || control_flow_ops.contains(region_while)) + if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx, + ops_to_visit, control_flow_blocks, token)) + continue; + } else if (auto region_terminator = dyn_cast(curr_op)) { + RewriteControlFlowTerminator(builder, region_terminator, token); + // There is no next op afer the control flow op terminator, simply let + // stack have one less element. + continue; + } else if (auto func_terminator = dyn_cast(curr_op)) { + if (rewrite_block) + RewriteFunctionTerminator(builder, func_terminator, token); + + // There is no next op afer the function terminator, simply let stack have + // one less element/be empty. + continue; + } + + // Visit next op. + ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op}); + } + + if (rewrite_block) UpdateFunctionType(builder, func, func_body); + + return success(); +} + +// Checks if a function call is pointing to a function with communication ops. +bool IsFunctionCallWithCommunication( + Operation* op, + const llvm::SmallDenseMap& funcs_to_rewrite) { + if (auto call = dyn_cast(op)) + return funcs_to_rewrite.count(call.callee()); + + return false; +} + +// Collects all control flow op ancestors of communication ops or function calls +// with communication ops (transitively). +void GetCommunicationControlFlowOps( + FuncOp func, + const llvm::SmallDenseMap& funcs_to_rewrite, + llvm::SmallPtrSetImpl& control_flow_ops, + llvm::SmallPtrSetImpl& control_flow_blocks) { + func.walk([&](Operation* op) { + if (IsCommunicationOp(op) || + IsFunctionCallWithCommunication(op, funcs_to_rewrite)) + if (failed(GetControlFlowAncestors(op, control_flow_ops, + control_flow_blocks))) + llvm_unreachable( + "checking original function for control flow ancestors should have " + "errored first"); + }); +} + +void LegalizeTFCommunication::runOnOperation() { + auto module = getOperation(); + llvm::SmallDenseMap funcs_to_rewrite; + if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite))) + return signalPassFailure(); + + // Module level counter to make sure Channel Id's are unique. 
+ int64_t channel_id = 1; + OpBuilder builder(&getContext()); + for (const auto& func_and_name : funcs_to_rewrite) { + const auto& func_to_rewrite = func_and_name.getSecond(); + FuncOp func = func_to_rewrite.original; + if (failed(RewriteFunction(builder, channel_id, module, func, + funcs_to_rewrite, + func_to_rewrite.control_flow_ops, + func_to_rewrite.control_flow_blocks, + /*is_clone=*/false))) + return signalPassFailure(); + + FuncOp clone = func_and_name.getSecond().clone; + if (!clone) continue; + llvm::SmallPtrSet clone_control_flow_ops; + llvm::SmallPtrSet clone_control_flow_blocks; + GetCommunicationControlFlowOps(clone, funcs_to_rewrite, + clone_control_flow_ops, + clone_control_flow_blocks); + if (failed(RewriteFunction(builder, channel_id, module, clone, + funcs_to_rewrite, clone_control_flow_ops, + clone_control_flow_blocks, + /*is_clone=*/true))) + llvm_unreachable( + "rewriting of original function should have errored first"); + } +} + +static PassRegistration pass( + "xla-legalize-tf-communication", + "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO " + "dialect"); +} // namespace + +std::unique_ptr> CreateLegalizeTFCommunicationPass() { + return std::make_unique(); +} + +} // namespace mhlo +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc index 09e94d9a84f..760252331e0 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_control_flow.cc @@ -119,10 +119,8 @@ void LowerIf(TF::IfOp op, ModuleOp module) { // Import the regions for both the true and false cases. These regions // must be updated to tuple the return results together and use the xla hlo // return op. - auto then_branch = module.lookupSymbol(op.then_branch()); - auto else_branch = module.lookupSymbol(op.else_branch()); - ImportXlaRegion(then_branch, &if_op.true_branch(), loc); - ImportXlaRegion(else_branch, &if_op.false_branch(), loc); + ImportXlaRegion(op.then_func(), &if_op.true_branch(), loc); + ImportXlaRegion(op.else_func(), &if_op.false_branch(), loc); // De-tuple the results of the xla hlo if result. Detuple(if_op.getResult(), op.getResults(), &builder); @@ -174,11 +172,9 @@ void LowerWhile(TF::WhileOp op, ModuleOp module) { // Import the regions for both the cond and body. These regions must be // updated to tuple the return results together and use the xla hlo return op. - auto body_branch = module.lookupSymbol(op.body()); - auto cond_branch = module.lookupSymbol(op.cond()); - - ImportXlaRegion(body_branch, &while_op.body(), loc); - ImportXlaRegion(cond_branch, &while_op.cond(), loc, /*tuple_return=*/false); + ImportXlaRegion(op.body_func(), &while_op.body(), loc); + ImportXlaRegion(op.cond_func(), &while_op.cond(), loc, + /*tuple_return=*/false); // De-tuple the results of the xla hlo while. 
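The communication pass above registers as xla-legalize-tf-communication and is exposed through CreateLegalizeTFCommunicationPass. A minimal sketch of adding it to a module-level pipeline, assuming the factory is declared in the transforms passes header:

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Sketch only: the pass rewrites whole modules (functions, calls, and
// control flow), so it is added at module level rather than nested.
static void AddCommunicationLegalization(mlir::PassManager& pm) {
  pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
}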
Detuple(while_op.getResult(), op.getResults(), &builder); diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 05e061337c7..1d4c9503afa 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -128,7 +128,7 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), // return x / y; // } // -// BraodcastToDimensions is used to compute the broadcast attr to higher +// BroadcastToDimensions is used to compute the broadcast attr to higher // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') // without returning the broadcast of 'r' to broadcast('l', 'r'). // @@ -143,14 +143,14 @@ def : Pat<(TF_FloorDivOp AnyRankedTensor:$l, AnyRankedTensor:$r), (HLOClient_BroadcastCompareOp $r, (HLO_ConstOp (GetScalarOfType<0> $r)), (NullDenseIntElementsAttr), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $l, $r), HLO_COMPARISON_DIRECTION_EQ), - (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), - (HLOClient_BroadcastDivOp - (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), + (HLOClient_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + (HLOClient_BroadcastDivOp + (HLO_NegOp:$neg (HLOClient_BroadcastAddOp (HLO_AbsOp $l), (HLOClient_BroadcastSubOp (HLO_AbsOp $r), (HLO_ConstOp (GetScalarOfType<1> $r)), (NullDenseIntElementsAttr)), (BinBroadcastDimensions $l, $r))), - (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), + (HLO_AbsOp:$abs $r), (BinBroadcastDimensions $neg, $abs))), [(SignedIntTensor $l)]>; // Performs a substitution of FloorMod designed to correct for possibly negative @@ -175,8 +175,8 @@ def : Pat<(TF_FloorModOp AnyRankedTensor:$l, AnyRankedTensor:$r), (BinBroadcastDimensions $rem, $r_zeros), HLO_COMPARISON_DIRECTION_LT), (BinBroadcastDimensions $r_cmp, $rem_cmp), HLO_COMPARISON_DIRECTION_NE), (NullDenseIntElementsAttr)), - (HLOClient_BroadcastAddOp $r, - $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; + (HLOClient_BroadcastAddOp $r, + $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; //===----------------------------------------------------------------------===// // Logical & bitwise binary op patterns. @@ -489,7 +489,7 @@ def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, $slice_sizes)]>; //===----------------------------------------------------------------------===// -// PartitionedCall op patterns. +// PartitionedCall and LegacyCall op patterns. //===----------------------------------------------------------------------===// def ArgTypesMatchCallee : Constraint< @@ -502,6 +502,12 @@ foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { [(ArgTypesMatchCallee $op, $args, $f)]>; } +// The extra attr on this op is _disable_call_shape_inference, which we ignore +// in the bridge. +def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), + (CallOp $f, $args), + [(ArgTypesMatchCallee $op, $args, $f)]>; + //===----------------------------------------------------------------------===// // Reverse op patterns. 
//===----------------------------------------------------------------------===// @@ -518,6 +524,7 @@ def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (TF_ConstOp $axis)), foreach Mapping = [ [TF_AbsOp, HLO_AbsOp], + [TF_AcosOp, HLOClient_AcosOp], [TF_CeilOp, HLO_CeilOp], [TF_ComplexAbsOp, HLO_AbsOp], [TF_CosOp, HLO_CosOp], @@ -540,6 +547,19 @@ foreach Mapping = [ (Mapping[1] $input)>; } +// Expand acos to MHLO dialect as follows: +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 +// = pi if x == -1 +def : Pat<(HLOClient_AcosOp $input), (HLO_SelectOp + (HLO_CompareOp $input, (HLO_ConstantLike<"0"> $input), + HLO_COMPARISON_DIRECTION_NE), + (HLO_MulOp (HLO_ConstantLike<"2.0f"> $input), + (HLO_Atan2Op + (HLO_SqrtOp (HLO_SubOp + (HLO_ConstantLike<"1"> $input), (HLO_MulOp $input, $input))), + (HLO_AddOp (HLO_ConstantLike<"1"> $input), $input))), + (HLO_ConstantLike<"M_PI"> $input))>; + // TODO(bixia): Lower Cast with a Complex type source operand or with // Truncate=True for floating point value conversions. def : Pat<(TF_CastOp HLO_Tensor:$arg, ConstBoolAttrFalse), @@ -557,17 +577,8 @@ foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { (HLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)]>; } -// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -def : Pat<(TF_SignOp $x), - (HLO_SelectOp - (HLO_CompareOp - $x, - $x, - HLO_COMPARISON_DIRECTION_NE - ), - (HLO_ConstOp (ConstantSplat<"0"> $x)), - (HLO_SignOp $x) - )>; +// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +def : Pat<(TF_SignOp $x), (HLO_SignOp $x)>; def BothElementTypesSameWidthIntOrFloat : Constraint; +// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic +// and going to MHLO. + //===----------------------------------------------------------------------===// // Random ops. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index d25b38d9ece..904b80e05b1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -24,12 +24,15 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -45,7 +48,8 @@ limitations under the License. 
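The acos expansion added to legalize_tf_patterns.td above relies on the identity acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) for x != -1, with the select supplying pi at x == -1. A quick standalone numerical check of that identity (using atan2, which matches the HLO_Atan2Op form of the pattern):

#include <cassert>
#include <cmath>
#include <cstdio>

// Samples x in (-1, 1]; the x == -1 branch of the select returns pi directly.
int main() {
  for (int i = -99; i <= 100; ++i) {
    double x = i / 100.0;
    double expanded = 2.0 * std::atan2(std::sqrt(1.0 - x * x), 1.0 + x);
    assert(std::fabs(expanded - std::acos(x)) < 1e-9);
  }
  std::printf("acos(x) == 2 * atan2(sqrt(1 - x^2), 1 + x) holds on (-1, 1]\n");
  return 0;
}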
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -70,12 +74,8 @@ limitations under the License. namespace mlir { namespace mhlo { -namespace { -template -using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok - -static bool IsOpAllowlisted(Operation* op) { +bool IsOpAllowedTf2XlaFallback(Operation* op) { // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels // building valid MLIR using MlirHloBuilder. // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for @@ -100,6 +100,8 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -152,9 +154,12 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -174,6 +179,7 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -182,6 +188,8 @@ static bool IsOpAllowlisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -209,8 +217,13 @@ static bool IsOpAllowlisted(Operation* op) { return ops.count(abstractOp->typeID); } +namespace { + +template +using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok + static std::unique_ptr CreateDeviceMgr( - const std::string& device_type, const Location& loc) { + const std::string& device_type) { // Register compilation kernels for all registered XLA backends. tensorflow::XlaOpRegistry::RegisterCompilationKernels(); @@ -219,42 +232,47 @@ static std::unique_ptr CreateDeviceMgr( return absl::make_unique(std::move(device)); } -class FuncLegalizer { +class Tf2XlaRewriter { public: - static LogicalResult Legalize(FuncOp func, const std::string& device_type) { - FuncLegalizer legalizer(func, device_type); - if (failed(legalizer.PrepareParams())) return failure(); - return legalizer.Legalize(); + static LogicalResult RewriteOp(Operation* op, PatternRewriter& rewriter, + const std::string& device_type) { + Tf2XlaRewriter tf2xla_rewriter(op, rewriter, device_type); + return tf2xla_rewriter.LegalizeOp(); } private: - FuncLegalizer(FuncOp func, const std::string& device_type) - : func_(func), device_type_(device_type), hlo_builder_(func) {} + Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter, + const std::string& device_type) + : op_(op), + device_type_(device_type), + rewriter_(rewriter), + hlo_builder_(op->getName().getStringRef().str(), rewriter_, + op->getLoc()), + context_(nullptr) {} - ~FuncLegalizer() { context_->Unref(); } + ~Tf2XlaRewriter() { + if (context_) context_->Unref(); + } // Prepares OpKernelContext params common to all the ops. // Emits an error on failure. LogicalResult PrepareParams(); - // Tries to legalize supported TensorFlow ops. - // Emits an error on failure. 
- LogicalResult Legalize(); - // Tries to legalize the specified TensorFlow op, if supported. // // Emits an error and returns failure if an error is encountered during // conversion. Note that success return value doesn't mean successful // legalization. - LogicalResult LegalizeOp(Operation* op); + LogicalResult LegalizeOp(); // Converts the given operand to expression of kind kConstant or kXlaOp. // Emits a remark and returns expression of kind kInvalid on failure. tensorflow::XlaExpression GetExprForOperand(Value operand, Operation* op); - FuncOp func_; + Operation* op_; std::string device_type_; + PatternRewriter& rewriter_; ::xla::MlirHloBuilder hlo_builder_; tensorflow::OpOrArgLocNameMapper name_mapper_; @@ -268,15 +286,14 @@ class FuncLegalizer { tensorflow::OpKernelContext::Params params_; }; -LogicalResult FuncLegalizer::PrepareParams() { +LogicalResult Tf2XlaRewriter::PrepareParams() { // XlaCompiler within the context is only used by the functional ops to // compile functions. We are not handling those at the moment so XlaCompiler // is not required. context_ = new tensorflow::XlaContext(/*compiler=*/nullptr, &hlo_builder_); context_->Ref(); - mlir::Location loc = func_.getLoc(); - device_mgr_ = CreateDeviceMgr(device_type_, loc); + device_mgr_ = CreateDeviceMgr(device_type_); if (!device_mgr_) return failure(); // Type of params_.device is DeviceBase* so store it as Device* to access @@ -296,18 +313,16 @@ LogicalResult FuncLegalizer::PrepareParams() { device_->resource_manager(), tensorflow::XlaContext::kXlaContextResourceName, context_); if (!status.ok()) { - emitError(loc) << "failed to create XlaContext resource: " - << status.ToString(); - return failure(); + return emitError(op_->getLoc()) + << "failed to create XlaContext resource: " << status.ToString(); } params_.step_container = step_container_.get(); tensorflow::StatusOr version_or = tensorflow::GetTfGraphProducerVersion( - func_.getParentOfType()); + op_->getParentOfType()); if (!version_or.ok()) { - emitError(loc) << version_or.status().ToString(); - return failure(); + return emitError(op_->getLoc()) << version_or.status().ToString(); } flib_def_ = absl::make_unique( @@ -319,62 +334,38 @@ LogicalResult FuncLegalizer::PrepareParams() { return success(); } -LogicalResult FuncLegalizer::Legalize() { - if (func_.empty()) return success(); - - // TensorFlow functions don't use CFGs. - if (!llvm::hasSingleElement(func_)) { - emitError(func_.getLoc()) << "requires at most one block in a TF function"; - return failure(); - } - Block& block = func_.front(); - - std::vector ops; - ops.reserve(block.getOperations().size()); - for (Operation& op : block.getOperations()) { - ops.push_back(&op); - } - - for (Operation* op : ops) { - if (failed(LegalizeOp(op))) return failure(); - } - return success(); -} - -LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { - if (!IsOpAllowlisted(op)) return success(); - +LogicalResult Tf2XlaRewriter::LegalizeOp() { // Only static shaped operands are supported in XLA builders for now. 
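The rewritten LegalizeOp above now returns failure (instead of silently succeeding) unless every operand is a ranked tensor with a fully static shape. A small standalone illustration of what that filter accepts and rejects, assuming the StandardTypes API already included by this file:

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

// Sketch only: mirrors the hasStaticShape() check in LegalizeOp.
bool OperandTypeIsSupported(mlir::Type ty) {
  auto ranked_ty = ty.dyn_cast<mlir::RankedTensorType>();
  return ranked_ty && ranked_ty.hasStaticShape();
}

void Demo(mlir::MLIRContext* context) {
  auto f32 = mlir::FloatType::getF32(context);
  // tensor<2x3xf32>: ranked and static -> supported.
  auto static_ty = mlir::RankedTensorType::get({2, 3}, f32);
  // tensor<?x3xf32>: ranked but dynamic in dim 0 -> rejected.
  auto dynamic_ty =
      mlir::RankedTensorType::get({mlir::ShapedType::kDynamicSize, 3}, f32);
  (void)OperandTypeIsSupported(static_ty);   // true
  (void)OperandTypeIsSupported(dynamic_ty);  // false
}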
- for (Type ty : op->getOperandTypes()) { + for (Type ty : op_->getOperandTypes()) { auto ranked_ty = ty.dyn_cast(); if (!ranked_ty || !ranked_ty.hasStaticShape()) { - op->emitRemark() << "lowering requires static shaped tensor operands"; - return success(); + return op_->emitRemark() + << "lowering requires static shaped tensor operands"; } } auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( - op, name_mapper_.GetUniqueName(op), /*ignore_unregistered_attrs=*/true); + op_, name_mapper_.GetUniqueName(op_), /*ignore_unregistered_attrs=*/true); if (!nodedef_or.ok()) { - op->emitRemark() << "failed to convert op to NodeDef: " - << nodedef_or.status().ToString(); - return success(); + return op_->emitRemark() << "failed to convert op to NodeDef: " + << nodedef_or.status().ToString(); } + if (failed(PrepareParams())) return failure(); + std::shared_ptr props; tensorflow::Status status = tensorflow::NodeProperties::CreateFromNodeDef( *nodedef_or.ValueOrDie(), params_.function_library->GetFunctionLibraryDefinition(), &props); if (!status.ok()) { - op->emitRemark() << "failed to create NodeProperties: " - << status.ToString(); - return success(); + return op_->emitRemark() + << "failed to create NodeProperties: " << status.ToString(); } tensorflow::OpKernel* op_kernel_raw; status = params_.function_library->CreateKernel(props, &op_kernel_raw); if (!status.ok()) { - op->emitRemark() << "failed to create tf2xla kernel: " << status.ToString(); - return success(); + return op_->emitRemark() + << "failed to create tf2xla kernel: " << status.ToString(); } // Transfer ownership of the kernel to a local smart pointer. auto op_kernel = absl::WrapUnique(op_kernel_raw); @@ -383,9 +374,8 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { status = tensorflow::XlaOpRegistry::CompileTimeConstantInputs( *op_kernel, &required_constants); if (!status.ok()) { - op->emitRemark() << "failed to compute required constants: " - << status.ToString(); - return success(); + return op_->emitRemark() + << "failed to compute required constants: " << status.ToString(); } llvm::SmallDenseSet required_consts; required_consts.insert(required_constants.begin(), required_constants.end()); @@ -395,89 +385,87 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { InlinedVector expressions; InlinedVector tensors; InlinedVector inputs; - expressions.reserve(op->getNumOperands()); - tensors.reserve(op->getNumOperands()); - inputs.reserve(op->getNumOperands()); + expressions.reserve(op_->getNumOperands()); + tensors.reserve(op_->getNumOperands()); + inputs.reserve(op_->getNumOperands()); // Prepare the list of Tensor inputs for the kernel. 
- for (auto it : llvm::enumerate(op->getOperands())) { + for (auto it : llvm::enumerate(op_->getOperands())) { Value operand = it.value(); size_t idx = it.index(); - tensorflow::XlaExpression expr = GetExprForOperand(operand, op); + tensorflow::XlaExpression expr = GetExprForOperand(operand, op_); tensorflow::XlaExpression::Kind kind = expr.kind(); - if (kind == tensorflow::XlaExpression::Kind::kInvalid) return success(); + if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure(); if (required_consts.count(idx) && kind != tensorflow::XlaExpression::Kind::kConstant) { - op->emitRemark() << "lowering requires operand #" << idx - << " to be a constant"; - return success(); + return op_->emitRemark() + << "lowering requires operand #" << idx << " to be a constant"; } expressions.push_back(expr); if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) { - op->emitRemark() << "skipping legalization due to unsupported type " - << operand.getType(); - return success(); + return op_->emitRemark() + << "skipping legalization due to unsupported type " + << operand.getType(); } auto shape_or = expr.GetShape(); if (!shape_or.ok()) { - op->emitRemark() << "failed to get shape for expression. " - << expr.HumanString(); - return success(); + return op_->emitRemark() + << "failed to get shape for expression. " << expr.HumanString(); } tensors.emplace_back( device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), shape_or.ValueOrDie()); tensorflow::Tensor& tensor = tensors.back(); - tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor); + tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor); inputs.emplace_back(&tensor); } params_.inputs = &inputs; params_.op_kernel = op_kernel.get(); llvm::SmallVector output_attr( - op->getNumResults()); + op_->getNumResults()); params_.output_attr_array = output_attr.data(); - hlo_builder_.setInsertionPoint(op); - hlo_builder_.SetLocation(op->getLoc()); + hlo_builder_.setInsertionPoint(op_); + hlo_builder_.SetLocation(op_->getLoc()); // Execute the kernel. - tensorflow::OpKernelContext op_context(¶ms_, op->getNumResults()); + tensorflow::OpKernelContext op_context(¶ms_, op_->getNumResults()); device_->Compute(params_.op_kernel, &op_context); if (!op_context.status().ok()) { - op->emitRemark() << "compilation to HLO failed: " - << op_context.status().ToString(); - return success(); + return op_->emitRemark() + << "compilation to HLO failed: " << op_context.status().ToString(); } // Replace uses of old results using the corresponding value after the // lowering. 
- for (int i = 0, e = op->getNumResults(); i < e; i++) { + llvm::SmallVector values; + values.reserve(op_->getNumResults()); + for (int i = 0, e = op_->getNumResults(); i < e; i++) { tensorflow::Tensor* output = op_context.mutable_output(i); const tensorflow::XlaExpression* expr = - tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output); + tensorflow::XlaExpression::CastExpressionFromTensor(*output); if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp) - return op->emitError( + return op_->emitError( "expects XlaExpression of kind kXlaOp in compiled output"); auto value = hlo_builder_.GetValue(expr->handle()); - mlir::OpResult old_result = op->getResult(i); + mlir::OpResult old_result = op_->getResult(i); if (value.getType() != old_result.getType()) { value = hlo_builder_.create(value, old_result.getType()); } - old_result.replaceAllUsesWith(value); + values.push_back(value); } - - op->erase(); + rewriter_.replaceOp(op_, values); return success(); } -tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand, - Operation* op) { +tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand(Value operand, + Operation* op) { ElementsAttr const_attr; auto defining_op = operand.getDefiningOp(); if (defining_op && matchPattern(defining_op, m_Constant(&const_attr))) { @@ -509,6 +497,23 @@ tensorflow::XlaExpression FuncLegalizer::GetExprForOperand(Value operand, return tensorflow::XlaExpression::XlaOp(xla_op, dtype); } +class Tf2XlaRewritePattern : public RewritePattern { + public: + // Set benefit to 0 (= least benefit) so this pattern is only used as a + // fallback. + explicit Tf2XlaRewritePattern(const std::string& device_type) + : RewritePattern(0, MatchAnyOpTypeTag()), device_type_(device_type) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (!IsOpAllowedTf2XlaFallback(op)) return failure(); + return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_); + } + + private: + std::string device_type_; +}; + class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; @@ -520,7 +525,9 @@ class LegalizeTF : public PassWrapper { LegalizeTF(const LegalizeTF&) {} void runOnFunction() override { - if (failed(FuncLegalizer::Legalize(getFunction(), device_type_))) + OwningRewritePatternList patterns; + patterns.insert(device_type_); + if (failed(applyPatternsAndFoldGreedily(getFunction(), patterns))) signalPassFailure(); } @@ -529,8 +536,7 @@ class LegalizeTF : public PassWrapper { // global device type for all TensorFlow ops. Option device_type_{ *this, "device-type", - llvm::cl::desc("XLA device type for execution of TensorFlow ops. " - "Supports XLA_CPU_JIT and XLA_TPU_JIT for now.")}; + llvm::cl::desc("XLA device type for execution of TensorFlow ops.")}; }; static PassRegistration pass( @@ -539,6 +545,11 @@ static PassRegistration pass( } // end namespace +void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, + OwningRewritePatternList& patterns) { + patterns.insert(device_type.str()); +} + std::unique_ptr> createLegalizeTfWithTf2XlaPass( llvm::StringRef device_type) { return std::make_unique(device_type); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 519068893e7..832bad2dcc8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -33,16 +33,20 @@ limitations under the License. 
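The new Tf2XlaRewritePattern above registers with benefit 0, so it only fires when no higher-benefit pattern matches, and PopulateLegalizeTfWithTf2XlaPatterns exposes it for reuse. A sketch of combining it with the direct TF-to-MHLO patterns in one greedy rewrite, assuming the populate functions keep the signatures shown in this diff (device type and header locations are assumptions):

#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
// Note: the header declaring applyPatternsAndFoldGreedily is assumed to be
// available here; adjust includes to match the real tree.

static mlir::LogicalResult LegalizeWithFallback(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  // Direct TF -> MHLO lowerings (benefit >= 1).
  mlir::mhlo::PopulateLegalizeTfPatterns(func.getContext(), &patterns);
  // tf2xla kernel fallback (benefit 0), tried only when nothing else matches.
  mlir::mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns);
  return mlir::applyPatternsAndFoldGreedily(func, patterns);
}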
#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -74,26 +78,8 @@ StatusOr> HloModuleFromProto( // Convert the MLIR `module` from HLO dialect to LHLO dialect using XLA for the // given platform. -Status ConvertModule(ModuleOp module, StringRef platform_name) { - SymbolTable symbol_table(module); - if (!symbol_table.lookup("main")) { - return ::xla::InvalidArgument( - "conversion to HLO module failed: missing main()"); - } - HloProto hlo_proto; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - ConvertMlirHloToHlo(module, &hlo_proto, - /*use_tuple_args=*/false, - /*return_tuple=*/false, - /*shape_representation_fn=*/nullptr), - "conversion to XLA HLO proto failed"); - - auto statusOrHloModule = HloModuleFromProto(hlo_proto); - TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(), - "parsing HLO proto to HLO module failed"); - std::unique_ptr hlo_module = - std::move(statusOrHloModule.ValueOrDie()); - +Status ConvertModule(std::unique_ptr hlo_module, ModuleOp module, + StringRef platform_name) { auto platform = ::xla::se::MultiPlatformManager::PlatformWithName( StringRefToView(platform_name)); if (!platform.ok()) { @@ -155,7 +141,29 @@ class XlaHloToLhloPass private: void runOnOperation() final { ModuleOp module = getOperation(); - Status status = ConvertModule(module, platform_); + + auto status = [&module, this]() -> Status { + SymbolTable symbol_table(module); + if (!symbol_table.lookup("main")) { + return ::xla::InvalidArgument( + "conversion to HLO module failed: missing main()"); + } + HloProto hlo_proto; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ConvertMlirHloToHlo(module, &hlo_proto, + /*use_tuple_args=*/false, + /*return_tuple=*/false, + /*shape_representation_fn=*/nullptr), + "conversion to XLA HLO proto failed"); + + auto statusOrHloModule = HloModuleFromProto(hlo_proto); + TF_RETURN_WITH_CONTEXT_IF_ERROR(statusOrHloModule.status(), + "parsing HLO proto to HLO module failed"); + std::unique_ptr hlo_module = + std::move(statusOrHloModule.ValueOrDie()); + + return ConvertModule(std::move(hlo_module), module, platform_); + }(); if (!status.ok()) { module.emitError() << status.ToString(); return signalPassFailure(); @@ -272,7 +280,6 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, } return Status::OK(); } - TF_ASSIGN_OR_RETURN(Type out_type, ::xla::ConvertShapeToType( current_shape, builder_)); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice, @@ -283,11 +290,35 @@ Status LhloDialectEmitter::CreateView(const HloInstruction* instr, return 
Status::OK(); } + auto out_memref_type = out_type.dyn_cast(); + if (!out_memref_type) + return tensorflow::errors::Internal( + "Expected memref type when creating a view for leaf type of a tuple."); + Value byte_shift = builder_.create(alloc.getLoc(), slice.offset()); - values->push_back(builder_.create(builder_.getUnknownLoc(), out_type, - alloc, byte_shift, - /*sizes=*/ValueRange{})); + + xla::Shape physical_shape = + xla::ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + current_shape); + TF_ASSIGN_OR_RETURN( + Type physical_out_type, + ::xla::ConvertShapeToType(physical_shape, builder_)); + + // TODO(timshen): revisit location handling. + Location loc = builder_.getUnknownLoc(); + + // ViewOp only takes memrefs without affine maps (layouts). Let ViewOp produce + // the physical shape (where dimensions are ordered in major to minor) first, + // then follow up with a StaticMemRefCastOp to cast the resulting memref to + // the original layout. + Value result = + builder_.create(loc, physical_out_type, alloc, byte_shift, + /*sizes=*/ValueRange{}); + if (physical_out_type != out_type) + result = builder_.create(loc, out_memref_type, + result); + values->push_back(result); return Status::OK(); } @@ -333,40 +364,43 @@ Status LhloDialectEmitter::Initialize() { for (const BufferAllocation& alloc : assignment_.Allocations()) ordered_allocations.push_back(&alloc); - // Sort the rather arbitrarily ordered allocations to match the input/output - // parameters. Specifically We want to sort buffer allocations in the - // following order: - // * Parameters always order before non-parameters. - // * Different parameters order by parameter number. - // * Different allocations for the same parameter order by the shape index. - // - // TODO(timshen): there should be only one non-parameter buffer, the temp - // buffer. Check on that. - const auto allocation_comparator = [](const BufferAllocation* lhs, - const BufferAllocation* rhs) { - if (lhs->is_entry_computation_parameter() != - rhs->is_entry_computation_parameter()) { - return lhs->is_entry_computation_parameter() > - rhs->is_entry_computation_parameter(); - } - if (lhs->is_entry_computation_parameter()) { - return std::tuple( - lhs->parameter_number(), lhs->param_shape_index()) < - std::tuple( - rhs->parameter_number(), rhs->param_shape_index()); - } - return false; - }; + if (computation_.IsEntryComputation()) { + // Sort the rather arbitrarily ordered allocations to match the input/output + // parameters. Specifically We want to sort buffer allocations in the + // following order: + // * Parameters always order before non-parameters. + // * Different parameters order by parameter number. + // * Different allocations for the same parameter order by the shape index. + // + // TODO(timshen): there should be only one non-parameter buffer, the temp + // buffer. Check on that. 
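      // For example, with two entry-computation parameters where parameter 0
      // is a tuple of two leaves, the comparator below yields the order:
      //   (param 0, shape index {0}), (param 0, {1}), (param 1, {}), then any
      //   non-parameter (temp) allocations.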
+ const auto allocation_comparator = [](const BufferAllocation* lhs, + const BufferAllocation* rhs) { + if (lhs->is_entry_computation_parameter() != + rhs->is_entry_computation_parameter()) { + return lhs->is_entry_computation_parameter() > + rhs->is_entry_computation_parameter(); + } + if (lhs->is_entry_computation_parameter()) { + return std::tuple( + lhs->parameter_number(), lhs->param_shape_index()) < + std::tuple( + rhs->parameter_number(), rhs->param_shape_index()); + } + return false; + }; - std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(), - allocation_comparator); + std::stable_sort(ordered_allocations.begin(), ordered_allocations.end(), + allocation_comparator); + } // The function signature will be composed of: // - one memref for each of the parameters. // - one memref for each other buffer allocation. llvm::SmallVector args_attrs; for (const BufferAllocation* alloc : ordered_allocations) { - if (alloc->is_entry_computation_parameter()) { + if (computation_.IsEntryComputation() && + alloc->is_entry_computation_parameter()) { const ::xla::Shape& buffer_shape = ::xla::ShapeUtil::GetSubshape( computation_.parameter_instruction(alloc->parameter_number()) ->shape(), @@ -379,6 +413,8 @@ Status LhloDialectEmitter::Initialize() { block->addArgument(arg_type); allocations_[alloc] = block->getArguments().back(); args_attrs.emplace_back(); + args_attrs.back().set(builder_.getIdentifier("lmhlo.alloc"), + builder_.getIndexAttr(alloc->index())); args_attrs.back().set(builder_.getIdentifier("lmhlo.params"), builder_.getIndexAttr(alloc->parameter_number())); } else { @@ -427,6 +463,22 @@ Status HloToLhloModule(const BufferAssignment& assignment, return computation->AcceptOrdered(&emitter, ordering); } +mlir::OwningModuleRef HloTextToLhloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context) { + StatusOr> maybe_module = + xla::ParseAndReturnUnverifiedModule( + absl::string_view(input.data(), input.size())); + TF_CHECK_OK(maybe_module.status()); + + mlir::OwningModuleRef module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + + TF_CHECK_OK( + ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host")); + + return module; +} + static PassRegistration registration( "xla-hlo-to-lhlo-with-xla", "Emit LHLO from HLO using the existing XLA implementation"); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index ca40eb5804c..bdc977616b1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -127,6 +127,9 @@ tensorflow::Status HloToLhloModule(const ::xla::BufferAssignment& assignment, const ::xla::HloModule& hlo_module, ModuleOp module); +OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input, + mlir::MLIRContext* context); + } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_MHLO_TO_LHLO_WITH_XLA_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index bc261324055..45166941620 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -18,6 +18,9 @@ limitations under the License. 
#include +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { @@ -33,14 +36,31 @@ namespace mhlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is /// false, emits an error if there is any operation that can't be legalized. +/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. std::unique_ptr> createLegalizeTFPass( - bool allow_partial_conversion = false, bool legalize_chlo = true); + bool allow_partial_conversion = false, bool legalize_chlo = true, + llvm::Optional tf2xla_fallback_device_type = llvm::None); /// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the /// specified device type. std::unique_ptr> createLegalizeTfWithTf2XlaPass( llvm::StringRef device_type); +/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list. +void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type, + OwningRewritePatternList& patterns); + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. +void PopulateLegalizeTfPatterns(MLIRContext* context, + OwningRewritePatternList* patterns); + +/// Checks whether the op is supported by the Tf2Xla fallback for legalization. +bool IsOpAllowedTf2XlaFallback(Operation* op); + /// Lowers from TF dialect's control flow to HLO dialect's control flow. std::unique_ptr> createLegalizeTFControlFlowPass(); @@ -48,8 +68,18 @@ std::unique_ptr> createLegalizeTFControlFlowPass(); /// dialect using the conversion patterns registered by the HLO dialect. When /// allow_partial_conversion is false, emits an error if there is any operation /// that can't be legalized. -LogicalResult legalizeTF(Operation* op, bool allow_partial_conversion = false, - bool legalize_chlo = true); +/// When `tf2xla_fallback_device_type` is not `None`, also uses legalization +/// patterns from TF2XLA fallback for provided device type (see +/// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is not +/// used. +LogicalResult legalizeTF( + Operation* op, bool allow_partial_conversion = false, + bool legalize_chlo = true, + llvm::Optional tf2xla_fallback_device_type = llvm::None); + +// Legalizes TF/XLA communication ops (TF dialect) to HLO dialect communication +// ops. +std::unique_ptr> CreateLegalizeTFCommunicationPass(); } // namespace mhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index b684abde7a5..afc36916348 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -145,11 +145,37 @@ Shape TypeToShape(mlir::Type type) { // For the primitive type case, the shape of the memref is similar to the // vector type case (i.e., it is, modulo the layout, the same dimensions // and primitive type). - // Currently we only return shapes for identity affine maps. - // TODO(andydavis) Map affine map layout function to XLA layout. 
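      // A minimal sketch of the stride-to-layout recovery implemented below
      // (hypothetical standalone helper, assuming LLVM ADT containers): for a
      // memref of shape [10, 20, 30] with strides [30, 300, 1] it returns
      // {2, 0, 1}, i.e. dimension 2 is minor-most (stride 1) and dimension 1
      // is major-most.
      static llvm::SmallVector<int64_t, 4> MinorToMajorFromStrides(
          llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
        llvm::SmallVector<std::pair<int64_t, int64_t>, 4> stride_and_dim;
        for (size_t i = 0; i < strides.size(); ++i)
          stride_and_dim.push_back({strides[i], static_cast<int64_t>(i)});
        std::sort(stride_and_dim.begin(), stride_and_dim.end());
        llvm::SmallVector<int64_t, 4> minor_to_major;
        int64_t expected_stride = 1;
        for (const auto& pair : stride_and_dim) {
          // A dense (perfectly strided) layout must have each stride equal to
          // the product of the sizes of all more-minor dimensions.
          if (pair.first != expected_stride) return {};
          minor_to_major.push_back(pair.second);
          expected_stride *= shape[pair.second];
        }
        return minor_to_major;
      }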
- if (m.getAffineMaps().empty() || - (m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity())) + if (m.getAffineMaps().empty()) return ShapeUtil::MakeShape(primitive_type, span); + + if (m.getAffineMaps().size() == 1) { + llvm::SmallVector strides; + int64_t offset; + if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + + llvm::SmallVector, 4> strides_with_indices; + for (const auto& e : llvm::enumerate(strides)) { + strides_with_indices.push_back({e.value(), e.index()}); + } + std::sort(strides_with_indices.begin(), strides_with_indices.end()); + + llvm::SmallVector minor_to_major; + int64_t stride = 1; + for (const auto& pr : strides_with_indices) { + minor_to_major.push_back(pr.second); + + // Either the affine map is not perfectly strided, or the dimensions + // recovered from strides don't match the actual dimensions in shapes. + if (stride != pr.first) return {}; + + stride *= m.getShape()[pr.second]; + } + + llvm::SmallVector dimensions(m.getShape().begin(), + m.getShape().end()); + return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, + minor_to_major); + } break; } case mlir::StandardTypes::RankedTensor: { diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index b2a7cb85686..a4a2bc42d99 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -178,5 +179,22 @@ TEST(TypeToShapeTest, ConvertWithShapeRepresentationFn) { EXPECT_EQ(captured_tensor_shape, tensorflow::TensorShape({1, 2, 3})); } +TEST(TypeToShapeTest, ConvertMemRefToShape) { + Shape shape = ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, {10, 20, 30}, + {2, 0, 1}); + MLIRContext context; + mlir::Builder builder(&context); + + StatusOr mlir_type = + ConvertShapeToType(shape, builder); + ASSERT_TRUE(mlir_type.ok()); + mlir::Type type = mlir_type.ConsumeValueOrDie(); + Shape converted = TypeToShape(type); + EXPECT_TRUE(ShapeUtil::Equal( + converted, ShapeUtil::MakeShapeWithLayout(PrimitiveType::F32, + {10, 20, 30}, {2, 0, 1}))); + EXPECT_TRUE(ShapeUtil::Equal(converted, shape)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc index 211470bd41e..158671a6242 100644 --- a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -183,3 +184,8 @@ static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate( "hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction); + +// MHLO doesn't support explicit layouts, while XLA service does. +// TODO(timshen): remove it once MHLO supports explicit layouts. +static mlir::TranslateToMLIRRegistration HloTextToLhloMlirTranslate( + "hlo-text-to-lhlo", mlir::HloTextToLhloTranslateFunction); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 0b5a6c147dc..924834fc0fc 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -21,10 +21,6 @@ package_group( includes = [ "//tensorflow/compiler/tf2xla:internal", ], - packages = [ - # To pass open source testing in the pip Kokoros. - "//bazel_pip/tensorflow/compiler/tests/...", - ], ) package_group( @@ -34,7 +30,6 @@ package_group( ], packages = [ # To pass open source testing in the pip Kokoros. - "//bazel_pip/tensorflow/compiler/tests/...", "//platforms/xla/tests/neural_nets", ], ) @@ -128,7 +123,6 @@ tf_xla_py_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -166,7 +160,6 @@ tf_xla_py_test( srcs = ["add_n_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. 
disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -331,6 +324,7 @@ tf_xla_py_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -694,7 +688,6 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -865,6 +858,7 @@ tf_xla_py_test( size = "medium", timeout = "long", srcs = ["matrix_diag_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -934,9 +928,8 @@ tf_xla_py_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", - shard_count = 10, + shard_count = 20, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -956,7 +949,7 @@ tf_xla_py_test( srcs = ["pooling_ops_3d_test.py"], enable_mlir_bridge = True, python_version = "PY3", - shard_count = 10, + shard_count = 20, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1193,6 +1186,10 @@ tf_xla_py_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "noasan", + "nomsan", + "notsan", + "optonly", ], deps = [ ":xla_test", @@ -1208,6 +1205,7 @@ tf_xla_py_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 3, tags = [ @@ -1244,7 +1242,6 @@ tf_xla_py_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1305,7 +1302,6 @@ tf_xla_py_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "config-cuda-only", @@ -1334,10 +1330,10 @@ tf_xla_py_test( srcs = ["tensor_list_ops_test.py"], # TensorList ops are not implemented in the on-demand compilation model yet. 
disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", # b/162025277 ], deps = [ ":xla_test", @@ -1889,7 +1885,6 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], - enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index 19a1d62cddd..9c941e791ee 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -8,6 +8,7 @@ load( "tf_cuda_tests_tags", "tf_exec_properties", ) +load("//tensorflow:tensorflow.bzl", "py_test") def all_backends(): b = ["cpu"] + plugins.keys() @@ -121,7 +122,7 @@ def tf_xla_py_test( updated_name = updated_name[:-5] updated_name += "_mlir_bridge_test" - native.py_test( + py_test( name = updated_name, srcs = srcs, srcs_version = "PY2AND3", diff --git a/tensorflow/compiler/tests/case_test.py b/tensorflow/compiler/tests/case_test.py index 3b2dff537da..4da9c4fac7a 100644 --- a/tensorflow/compiler/tests/case_test.py +++ b/tensorflow/compiler/tests/case_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for while loops in XLA.""" +"""Tests for case statements in XLA.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 520348e0f8a..eef9d24766d 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -311,7 +311,7 @@ class EagerFunctionTest(xla_test.XLATestCase): if 'GPU' in self.device: # TODO(b/32333178) self.skipTest('Current implementation of RandomStandardNormal kernel ' - 'is very slow on GPU, and has been blacklisted.') + 'is very slow on GPU, and has been denylisted.') with self.test_scope(): data_format = 'channels_last' conv = convolutional.Conv2D( diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 326c3ec4929..9590688fda7 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,7 +30,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -775,7 +774,6 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): - @test_util.disable_mlir_bridge("%1") def testNMS128From1024(self): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -810,7 +808,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) - @test_util.disable_mlir_bridge("%1") def testNMS3From6Boxes(self): # Three boxes are selected based on IOU. 
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -852,7 +849,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 3) self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) - @test_util.disable_mlir_bridge("%1") def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -895,7 +891,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) - @test_util.disable_mlir_bridge("%1") def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -939,7 +934,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 1) self.assertAllClose(indices_tf[:num_valid], [3]) - @test_util.disable_mlir_bridge("%1") def testSelectFromContinuousOverLap(self): # Tests that a suppressed box does not itself suppress other boxes. @@ -984,7 +978,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1022,7 +1015,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1056,7 +1048,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1087,7 +1078,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1117,7 +1107,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1151,7 +1140,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1187,7 +1175,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], 
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1224,7 +1211,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1262,7 +1248,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1298,7 +1283,6 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) - @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], diff --git a/tensorflow/compiler/tests/pooling_ops_3d_test.py b/tensorflow/compiler/tests/pooling_ops_3d_test.py index a833daa39be..9eda74b55a9 100644 --- a/tensorflow/compiler/tests/pooling_ops_3d_test.py +++ b/tensorflow/compiler/tests/pooling_ops_3d_test.py @@ -75,9 +75,6 @@ class Pooling3DTest(xla_test.XLATestCase): actual = vals.flatten() self.assertAllClose(expected, actual) - @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering" - " doesn't support all paddings and data " - "formats") def testAvgPool3dValidPadding(self): expected_output = [20.5, 21.5, 22.5] self._VerifyValues( @@ -88,9 +85,6 @@ class Pooling3DTest(xla_test.XLATestCase): padding="VALID", expected=expected_output) - @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering" - " doesn't support all paddings and data " - "formats") def testAvgPool3dSamePadding(self): expected_output = [20.5, 21.5, 22.5, 26.5, 27.5, 28.5] self._VerifyValues( @@ -101,9 +95,6 @@ class Pooling3DTest(xla_test.XLATestCase): padding="SAME", expected=expected_output) - @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering" - " doesn't support all paddings and data " - "formats") def testAvgPool3dSamePaddingDifferentStrides(self): expected_output = [1.5, 4.5, 7.5, 17.5, 20.5, 23.5, 33.5, 36.5, 39.5] self._VerifyValues( diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py index 293e1010b08..d9393387c0d 100644 --- a/tensorflow/compiler/tests/pooling_ops_test.py +++ b/tensorflow/compiler/tests/pooling_ops_test.py @@ -268,9 +268,6 @@ class PoolingTest(xla_test.XLATestCase): expected=[1, 3, 9, 11]) # Average pooling - @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering" - " doesn't support all paddings and data " - "formats") def testAvgPoolValidPadding(self): expected_output = [7, 8, 9] self._VerifyValues( @@ -281,9 +278,6 @@ class PoolingTest(xla_test.XLATestCase): padding="VALID", expected=expected_output) - @test_util.disable_mlir_bridge("TODO(b/159812644): AvgPool TF to HLO lowering" - " doesn't support all paddings and data " - "formats") def testAvgPoolSamePadding(self): expected_output = [7., 8., 9., 11.5, 12.5, 13.5] self._VerifyValues( diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 
d50fdec7c63..838718aa1e3 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -129,42 +129,35 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 not in self.numeric_types: - return - - with self.session() as sess: - p = array_ops.placeholder(dtypes.bfloat16) - with self.test_scope(): - topk = nn_ops.top_k(p, k=4) - results = sess.run( - topk, - {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=bfloat16)}) - self.assertAllEqual( - np.array([3., 0., 0., 0.], dtype=bfloat16), results[0]) - self.assertEqual(list([3, 0, 2, 6]), list(results[1])) + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + with self.session() as sess: + p = array_ops.placeholder(dtype) + with self.test_scope(): + topk = nn_ops.top_k(p, k=4) + results = sess.run( + topk, + {p: np.array([0., -0., 0., 3., -0., -4., 0., -0.], dtype=dtype)}) + self.assertAllEqual(np.array([3., 0., 0., 0.], dtype=dtype), results[0]) + self.assertEqual(list([3, 0, 2, 6]), list(results[1])) def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # Only bfloat16 is implemented. - bfloat16 = dtypes.bfloat16.as_numpy_dtype - if bfloat16 not in self.numeric_types: - return - - with self.session() as sess: - p = array_ops.placeholder(dtypes.bfloat16) - with self.test_scope(): - topk = nn_ops.top_k(p, k=6) - results = sess.run(topk, { - p: np.array( - [1, 2, float("inf"), -float("inf"), -1, -2], dtype=bfloat16) - }) - self.assertAllEqual( - np.array( - [float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], - dtype=bfloat16), results[0]) - self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) + supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + for dtype in supported_types.intersection(self.numeric_types): + with self.session() as sess: + p = array_ops.placeholder(dtype) + with self.test_scope(): + topk = nn_ops.top_k(p, k=6) + results = sess.run(topk, { + p: + np.array([1, 2, float("inf"), -float("inf"), -1, -2], + dtype=dtype) + }) + self.assertAllEqual( + np.array([float("inf"), 2.0, 1.0, -1.0, -2.0, -float("inf")], + dtype=dtype), results[0]) + self.assertEqual(list([2, 1, 0, 4, 5, 3]), list(results[1])) def testInTopK(self): supported_types = set([np.int32, np.int64]) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 162693a9eb1..eb022da6895 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import unittest import numpy as np +import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test @@ -90,6 +91,10 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result, expected, rtol, atol) self.assertAllEqual(np.sort(result), result) + def AssertAllEqual(self, result, expected, rtol, atol): + """Tests that result and expeted are exactly equal.""" + self.assertAllEqual(result, expected) + @test_util.disable_mlir_bridge( "MlirHloBuilder::Iota missing required for xla::Diag") def testAllTypeOps(self): @@ -435,8 +440,12 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( math_ops.sign, - np.array([[-2.0, -1.0, 
-0.0, +0.0, 1.0, 2.0]], dtype=dtype), - expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype)) + np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0, + float("nan")]], + dtype=dtype), + expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0, + float("nan")]], + dtype=dtype)) self._assertOpOutputMatchesExpected( math_ops.is_finite, @@ -775,6 +784,10 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) + @test_util.disable_mlir_bridge( + "TF_PopulationCount is missing and is required to translate to " + "xla::PopulationCount." + ) def testIntOps(self): for dtype in self.int_types: self._assertOpOutputMatchesExpected( @@ -782,6 +795,38 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([0, -1, 1, 16, 42], dtype=dtype), expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) + # Test population_count for array inputs. + raw_inputs = [ + 0, 1, -1, 3, -3, 5, -5, 14, -14, 127, 128, 255, 256, 65535, 65536, + 2**31 - 1, 2**31, 2**32 - 1, 2**32, -2**32 + 1, -2**32, -2**63 + 1, + 2**63 - 1 + ] + # Only choose inputs which fit in the int dtype. + raw_inputs = list( + filter(lambda x: np.iinfo(dtype).min <= x <= np.iinfo(dtype).max, + raw_inputs)) + inputs = np.array(raw_inputs, dtype=dtype) + + def count_bits(x): + return sum(bin(z).count("1") for z in six.iterbytes(x.tobytes())) + + truth = [count_bits(x) for x in inputs] + self._assertOpOutputMatchesExpected( + bitwise_ops.population_count, + inputs, + expected=np.array(truth, dtype=np.uint8), + equality_test=self.AssertAllEqual) + + # Test population_count for scalar inputs. + for raw_inp in raw_inputs: + inp = dtype(raw_inp) + truth = count_bits(inp) + self._assertOpOutputMatchesExpected( + bitwise_ops.population_count, + inp, + expected=np.uint8(truth), + equality_test=self.AssertAllEqual) + def testNumericOps(self): for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( @@ -923,16 +968,22 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 0x100000003f800000], np.uint64)) def testInvertPermutation(self): - self._assertOpOutputMatchesExpected( - array_ops.invert_permutation, - np.array([1, 2, 0], np.int32), - expected=np.array([2, 0, 1], dtype=np.int32)) + for np_dtype in [np.int32, np.int64]: + self._assertOpOutputMatchesExpected( + array_ops.invert_permutation, + np.array([1, 2, 0], np_dtype), + expected=np.array([2, 0, 1], dtype=np_dtype)) def testInvertPermutationTwiceIsNoop(self): - self._assertOpOutputMatchesExpected( - lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), - np.array([1, 2, 0], np.int32), - expected=np.array([1, 2, 0], dtype=np.int32)) + + def invert_twice(x): + return array_ops.invert_permutation(array_ops.invert_permutation(x)) + + for np_dtype in [np.int32, np.int64]: + self._assertOpOutputMatchesExpected( + invert_twice, + np.array([1, 2, 0], np_dtype), + expected=np.array([1, 2, 0], dtype=np_dtype)) def testRank(self): rank_op = lambda x: array_ops.rank_internal(x, optimize=False) diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index f5f63cb60aa..3b057ed8b17 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -236,9 +236,7 @@ class XLATestCase(test.TestCase): @contextlib.contextmanager def test_scope(self): - """Test scope that runs tests on a Tensorflow/XLA device. - - Uses a compilation_scope() to mark operators to compile. 
+ """Test scope that runs tests on `self.device`. Yields: A scope to apply to the operators under test. diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index c9210a1a1e7..c4fc3e4f5da 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/config.pb.h" // NOLINT @@ -332,7 +333,6 @@ void UpdateToEngineNode(const std::vector& infos, Status CreateTRTNode(const ConversionParams& params, const std::vector& infos, int pos, int max_batch_size, Graph* graph, - nvinfer1::IGpuAllocator* alloc, std::vector* engine_nodes) { const auto& info = infos.at(pos); std::vector input_shape_protos; @@ -428,16 +428,30 @@ Status CreateTRTNode(const ConversionParams& params, // Build the engine and get its serialized representation. string segment_string; if (info.engine_type == EngineInfo::EngineType::TRTStatic) { + std::pair device_allocator = + GetDeviceAndAllocator(params, info); + int cuda_device_id = 0; + std::unique_ptr trt_allocator; + if (device_allocator.first >= 0) { + cuda_device_id = device_allocator.first; + trt_allocator.reset(new TRTDeviceAllocator(device_allocator.second)); + } else { + // The value in trt_allocator is a nullptr and cudamalloc will be used. + LOG_WARNING_WITH_PREFIX << "Can't identify the cuda device. Running on " + "device 0 and use cudamalloc as an allocator"; + } + cudaSetDevice(cuda_device_id); + auto trt_logger = GetLoggerRegistry()->LookUp(params.trt_logger_name); - // Create static engine for fp32/fp16 mode. + // Create static engines with precision_mode fp32/fp16. TrtUniquePtrType engine; - // TODO(sami): What happens if 1st dim is not batch? TF_RETURN_IF_ERROR(ConvertGraphDefToEngine( info.segment_graph_def, calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode, max_batch_size, info.max_workspace_size_bytes, input_shapes, trt_logger, - alloc, /*calibrator=*/nullptr, &engine, info.use_calibration, - params.use_implicit_batch, /*convert_successfully=*/nullptr, + trt_allocator.get(), /*calibrator=*/nullptr, &engine, + info.use_calibration, params.use_implicit_batch, + /*convert_successfully=*/nullptr, /*profile=*/nullptr)); TrtUniquePtrType engine_data(engine->serialize()); segment_string = string(static_cast(engine_data->data()), @@ -793,13 +807,27 @@ Status ConvertAfterShapes(const ConversionParams& params) { } } - // Create a TRT node for each segment using its EngineInfo. - int old_cuda_device = 0; - auto err = cudaGetDevice(&old_cuda_device); - if (err != cudaSuccess) { - LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err); + // Save the cuda device if we may need to switch to another cuda device to + // build static engines. 
+ absl::optional old_cuda_device = absl::nullopt; + if (!params.is_dyn_op) { + int cuda_device_id; + cudaError_t cuda_error = cudaGetDevice(&cuda_device_id); + if (cuda_error != cudaSuccess) { + LOG_WARNING_WITH_PREFIX << "Couldn't get current device: " + << cudaGetErrorString(cuda_error); + } else { + VLOG(1) << "Current cuda device is " << cuda_device_id; + old_cuda_device = cuda_device_id; + } } - VLOG(1) << "Current cuda device is " << old_cuda_device; + + auto restore_cuda_device = gtl::MakeCleanup([old_cuda_device] { + if (old_cuda_device.has_value()) { + cudaSetDevice(old_cuda_device.value()); + } + }); + std::vector engine_nodes; engine_nodes.resize(engine_segments.size()); for (int i = 0; i < engine_segments.size(); ++i) { @@ -813,24 +841,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { 2.0; VLOG(1) << "Assigned " << engine.max_workspace_size_bytes << " bytes to " << engine.engine_name; - // The allocator is used to build the engine. The build and the built engine - // will be destroyed after we get the serialized engine string, so it's fine - // to use unique_ptr here. - std::unique_ptr alloc; - auto device_alloc = GetDeviceAndAllocator(params, engine); - int cuda_device_id = 0; - if (device_alloc.first >= 0) { - cuda_device_id = device_alloc.first; - alloc.reset(new TRTDeviceAllocator(device_alloc.second)); - } else { - // Setting allocator as nullptr should get revert to the cudamalloc - LOG_WARNING_WITH_PREFIX - << "Can't identify the cuda device. Running on device 0 "; - } - cudaSetDevice(cuda_device_id); - auto status = - CreateTRTNode(params, engine_segments, i, params.max_batch_size, &graph, - alloc.get(), &engine_nodes); + auto status = CreateTRTNode(params, engine_segments, i, + params.max_batch_size, &graph, &engine_nodes); string msg = StrCat("segment ", i, " consisting of ", converted_segments.at(i).size(), " nodes by ", @@ -859,7 +871,6 @@ Status ConvertAfterShapes(const ConversionParams& params) { } } } - cudaSetDevice(old_cuda_device); graph.ToGraphDef(params.output_graph_def); VLOG(1) << "Returning from conversion"; return Status::OK(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index 54fb1d56441..3b0553426c0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -71,7 +71,7 @@ class FakeCluster : public grappler::Cluster { } private: - const DeviceSet* device_set_; + const DeviceSet* device_set_ = nullptr; }; TEST(ConvertGraphTest, GetDeviceAndAllocator) { diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 369b339d01a..f80c0f42eca 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -2410,6 +2410,40 @@ Status ConvertTranspose(OpConverterParams* params) { return Status::OK(); } +Status ConvertShape(OpConverterParams* params) { + const auto& inputs = params->inputs; + TF_RETURN_IF_ERROR( + CheckInputsWeights(*params, {{"input", TrtInputArg::kBoth}})); + if (params->use_implicit_batch) { + return errors::Unimplemented( + "Shape is only supported for explicit batch mode."); + } + if (HasStaticShape(inputs.at(0).GetTrtDims())) { + if (params->validation_only) return Status::OK(); + nvinfer1::Dims input_dims = inputs.at(0).GetTrtDims(); + nvinfer1::Dims output_dims{1, {input_dims.nbDims}}; + // Create a 
const node with the values of output_dims + TRT_ShapedWeights weight = params->weight_store->GetTempWeights( + nvinfer1::DataType::kINT32, output_dims); + int32* values_ptr = static_cast(weight.GetValues()); + std::copy(input_dims.d, input_dims.d + input_dims.nbDims, values_ptr); + auto output = params->converter->CreateConstantLayer(weight, output_dims); + params->outputs->push_back(TRT_TensorOrWeights(output)); + return Status::OK(); + } +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + if (params->validation_only) return Status::OK(); + nvinfer1::IShapeLayer* shape_layer = + params->converter->network()->addShape(*inputs.at(0).tensor()); + TFTRT_RETURN_ERROR_IF_NULLPTR(shape_layer, params->node_def.name()); + params->outputs->push_back(TRT_TensorOrWeights(shape_layer->getOutput(0))); + return Status::OK(); +#else + return errors::Unavailable( + "Shape op conversion requires TensorRT 6 or above"); +#endif +} + Status ConvertReshape(OpConverterParams* params) { const auto& inputs = params->inputs; TF_RETURN_IF_ERROR( @@ -3510,8 +3544,13 @@ Status ConvertPool(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}})); - TF_RETURN_IF_ERROR( - AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); +#if IS_TRT_VERSION_GE(5, 1, 0, 0) + std::set allowed_types{DataType::DT_FLOAT, DataType::DT_HALF, + DataType::DT_INT8}; +#else + std::set allowed_types{DataType::DT_FLOAT, DataType::DT_HALF}; +#endif + TF_RETURN_IF_ERROR(AllowDataTypes(*params, allowed_types)); nvinfer1::PoolingType type; if (node_def.op() == "MaxPool") { type = nvinfer1::PoolingType::kMAX; @@ -3744,6 +3783,7 @@ Status ConvertActivation(OpConverterParams* params) { params->converter->network()->addActivation(*inputs.at(0).tensor(), op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); // Set parameters. #if IS_TRT_VERSION_GE(5, 1, 2, 0) if (node_def.op() == "Elu") { @@ -3844,9 +3884,10 @@ Status ConvertRelu6(OpConverterParams* params) { nvinfer1::IActivationLayer* layer = params->converter->network()->addActivation( *inputs.at(0).tensor(), nvinfer1::ActivationType::kCLIP); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); layer->setAlpha(0.0f); layer->setBeta(6.0f); - TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); @@ -4402,6 +4443,7 @@ Status ConvertUnary(OpConverterParams* params) { nvinfer1::IUnaryLayer* layer = params->converter->network()->addUnary(*tensor, op_pair->second); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); // Set quantization ranges. 
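  // Note on the setName(...) calls added in this change: tagging each TensorRT
  // layer with the originating TF node name makes TensorRT logs and profiling
  // output easier to trace back to the TensorFlow graph.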
@@ -4479,7 +4521,7 @@ Status ConvertReduce(OpConverterParams* params) { int trt_axis; TF_RETURN_IF_ERROR( ConvertAxis(tf_axes_list[i], tensor->getDimensions().nbDims, - node_def.name(), /*use_implicit_batch=*/true, &trt_axis)); + node_def.name(), params->use_implicit_batch, &trt_axis)); axes |= (1 << trt_axis); } @@ -4941,7 +4983,18 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { node_def.name()); } nvinfer1::ITensor* tensor = inputs.at(0).tensor(); - + if (!params->use_implicit_batch && tensor->getDimensions().d[1] == -1) { + // This check is to make sure that channel dimension is known during + // conversion. + // + // We check this only in explicit batch mode and reject an op with unknown + // channel dimension during segmentation. In implicit batch mode we have + // known shapes during conversion even though the shapes may not be known + // during segmentation (see the actual argument for input_shapes when + // ConvertGraphDefToEngine is called from TRTEngineOp::BuildEngine). + return errors::InvalidArgument("Channel dimension must be static, at ", + node_def.name()); + } // Check parameter types auto parameter_type = inputs.at(1).weights().TrtDType(); if ((parameter_type != nvinfer1::DataType::kFLOAT) && @@ -5039,6 +5092,7 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { combined_scale_weights.GetTrtWeights(), dummy_power_weights.GetTrtWeights()); TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + layer->setName(node_def.name().c_str()); nvinfer1::ITensor* output_tensor = layer->getOutput(0); params->outputs->push_back(TRT_TensorOrWeights(output_tensor)); return Status::OK(); @@ -5958,6 +6012,7 @@ static void RegisterValidatableOpConverters( (*registration)[pool_op_type] = ConvertPool3D; } #endif + (*registration)["Shape"] = ConvertShape; (*registration)["Rsqrt"] = ConvertRsqrt; (*registration)["Slice"] = ConvertSlice; (*registration)["Softmax"] = ConvertSoftmax; diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 52d05ff8225..aeae44a5562 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -1309,7 +1309,8 @@ std::vector GetDataAsFloat(InputOutputData& data) { class OpConverterTest : public ::testing::Test { public: OpConverterTest() - : scope_(Scope::NewRootScope()), allocator_(new GpuManagedAllocator()) { + : tensor_buffer_allocator_(new GpuManagedAllocator()), + scope_(Scope::NewRootScope()) { QCHECK_EQ(0, cudaStreamCreate(&stream_)); Reset(); } @@ -1341,7 +1342,7 @@ class OpConverterTest : public ::testing::Test { // Constructs a flat tensor with 'vals' in Unified Memory. 
template Tensor AsTensor(gtl::ArraySlice vals) { // non-absl ok - Tensor ret(allocator_.get(), DataTypeToEnum::value, + Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum::value, {static_cast(vals.size())}); std::copy_n(vals.data(), vals.size(), ret.flat().data()); return ret; @@ -1351,7 +1352,7 @@ class OpConverterTest : public ::testing::Test { template Tensor AsTensor(gtl::ArraySlice vals, // non-absl ok const TensorShape& shape) { - Tensor ret(allocator_.get(), DataTypeToEnum::value, + Tensor ret(tensor_buffer_allocator_.get(), DataTypeToEnum::value, {static_cast(vals.size())}); CHECK(ret.CopyFrom(AsTensor(vals), shape)); return ret; @@ -1363,7 +1364,8 @@ class OpConverterTest : public ::testing::Test { template Tensor AsTensor(std::vector vals, const std::vector input_dims, DataType tf_type) { - Tensor ret(allocator_.get(), tf_type, {static_cast(vals.size())}); + Tensor ret(tensor_buffer_allocator_.get(), tf_type, + {static_cast(vals.size())}); if (tf_type == DT_FLOAT) { auto conv_vals = CastTestVector(vals); std::copy_n(conv_vals.data(), conv_vals.size(), ret.flat().data()); @@ -1646,13 +1648,15 @@ class OpConverterTest : public ::testing::Test { Logger logger_; TrtUniquePtrType engine_; cudaStream_t stream_; - // Used to create placeholders with shape and data type information. The - // created placeholders will be used as inputs to the node to be verified, - // thus we need the shape and data type information to get a non-empty - // GraphProperties. + std::unique_ptr tensor_buffer_allocator_; + // The scope that contains the graph being converted. Because + // tensor_buffer_allocator_ provides the storage for tensor contents that are + // represented as attributes for graph nodes within scope_, + // tensor_buffer_allocator_ needs to be available when destructing scope_. + // Therefore, scope_ comes after tensor_buffer_allocator_ in the class member + // field list. Scope scope_; std::unordered_map node_inputs_; - std::unique_ptr allocator_; }; // General test parameters to be used with ops that take a single input tensor. @@ -1781,7 +1785,8 @@ class ParameterizedOpConverterTestBase void BuildAndRun(const string& name, const std::vector>& expected_output_dims, const Status& expected_runtime_status, - const std::vector>>& matcher) { + const std::vector>>& matcher, + const std::vector& out_tf_types = {}) { TensorShape shape; const int n_output = expected_output_dims.size(); ASSERT_EQ(n_output, matcher.size()); @@ -1790,12 +1795,14 @@ class ParameterizedOpConverterTestBase TF_EXPECT_OK( TensorShapeUtils::MakeShape(expected_output_dims[i], &shape)); string out_name = (n_output == 1) ? name : StrCat(name, ":", i); - InputOutputData data{out_name, - ConstructTensor(shape.num_elements(), 0, tf_type)}; + DataType out_tf_type = + out_tf_types.size() > i ? out_tf_types[i] : tf_type; + InputOutputData data{ + out_name, ConstructTensor(shape.num_elements(), 0, out_tf_type)}; output_data.push_back(data); } - ASSERT_FALSE(input_data_.empty()); - const int batch_size = input_data_[0].tensor.shape().dim_size(0); + const int batch_size = + input_data_.empty() ? 
1 : input_data_[0].tensor.shape().dim_size(0); Status stat = OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); ASSERT_EQ(expected_runtime_status.ok(), stat.ok()) @@ -1820,13 +1827,15 @@ class ParameterizedOpConverterTestBase const std::vector& expected_output_dims, const Status& expected_conversion_status, const Status& expected_runtime_status, - const Matcher>& matcher) { + const Matcher>& matcher, + const std::vector& out_tf_types = {}) { RunValidationAndConversion(node_def, expected_conversion_status, name.c_str(), expected_output_dims); if (expected_conversion_status.ok()) { BuildAndRun(name, std::vector>({expected_output_dims}), expected_runtime_status, - std::vector>>({matcher})); + std::vector>>({matcher}), + out_tf_types); } } @@ -2011,6 +2020,142 @@ TEST_F(OpConverterTest, ConvertConst) { TestConvertConst(this); } +template +NodeDef CreateFusedBatchNormOp(DataType tf_type, std::string data_format, + bool is_training, float epsilon) { + Scope s = Scope::NewRootScope(); + auto x = ops::Placeholder(s.WithOpName("x"), tf_type); + auto scale = ops::Placeholder(s.WithOpName("scale"), tf_type); + auto offset = ops::Placeholder(s.WithOpName("offset"), tf_type); + auto mean = ops::Placeholder(s.WithOpName("mean"), tf_type); + auto variance = ops::Placeholder(s.WithOpName("variance"), tf_type); + typename T::Attrs attrs; + attrs.data_format_ = data_format; + attrs.is_training_ = is_training; + if (epsilon > 0) { + attrs.epsilon_ = epsilon; + } else { + EXPECT_GE(epsilon, 0); + } + return T(s.WithOpName("my_batchnorm"), x, scale, offset, mean, variance, + attrs) + .operation.node() + ->def(); +} + +TEST_P(OpConverterTest1, ConvertFusedBatchNorm) { + using OpFunc = std::function; + std::vector get_node_def_vec{ + CreateFusedBatchNormOp, + CreateFusedBatchNormOp, + CreateFusedBatchNormOp}; + + struct TestParam { + std::string data_format; + int tensor_input_idx; // Index of an input that will be provided as tensor. 
+ bool is_training; + float epsilon; + Status conversion_status; + bool keep_channel_unknown; + }; + + struct NodeInput { + std::string name; + std::vector dims; + std::vector val; + }; + std::vector node_input{ + {"x", {2, 3, 2, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}, + {"scale", {3}, {7, 8, 9}}, + {"offset", {3}, {10, 20, 30}}, + {"mean", {3}, {1, 2, 3}}, + {"variance", {3}, {4, 5, 6}}}; + + std::vector expected_output{10.0, 13.495633, 23.574135, 27.148273, + 37.342354, 41.013527, 30.9738, 34.469433, + 45.018955, 48.59309, 59.369415, 63.04059}; + for (auto get_node_def : get_node_def_vec) { + NodeDef tmp_node_def = get_node_def(tf_type, "NCHW", true, 0); + std::string op_name = tmp_node_def.op(); + std::vector test_param{ + {"NHWC", 0, false, 0, + errors::Unimplemented(StrCat( + op_name, " only supports data_format=NCHW, at my_batchnorm"))}, + {"NCHW", 0, true, 0, + errors::Unimplemented(StrCat( + op_name, " only supports is_training=false, at my_batchnorm"))}, + {"NCHW", 1, false, 0, + errors::Unimplemented(StrCat("The input \"scale\" for ", op_name, + " must be a constant, at my_batchnorm"))}, + {"NCHW", 2, false, 0, + errors::Unimplemented(StrCat("The input \"offset\" for ", op_name, + " must be a constant, at my_batchnorm"))}, + {"NCHW", 3, false, 0, + errors::Unimplemented(StrCat("The input \"mean\" for ", op_name, + " must be a constant, at my_batchnorm"))}, + {"NCHW", 4, false, 0, + errors::Unimplemented(StrCat("The input \"variance\" for ", op_name, + " must be a constant, at my_batchnorm"))}, + {"NCHW", 0, false, 0.01}}; // The last one is the only test that runs. + if (trt_mode == TrtTestMode::kDynamicShape) { + test_param.push_back( + {"NCHW", 0, false, 0.01, + errors::InvalidArgument( + "Channel dimension must be static, at my_batchnorm"), + true}); + } + for (auto p : test_param) { + Reset(); + NodeDef node_def = + get_node_def(tf_type, p.data_format, p.is_training, p.epsilon); + for (int i = 0; i < node_input.size(); i++) { + if (i == 0 || i == p.tensor_input_idx) { + // The first input (x) is always added as a tensor, and it hase shape + // NCHW. The other inputs are per channel values (1D, size C). + // + // In implicit batch mode, it is not possible to add any of the 1D + // inputs as a tensor: the first dim is always treated as batch dim in + // implicit batch mode, and that has to agree for all tensors. We have + // two input tensors with shapes NCHW and C and in general N != C. + // The converter already picked up N from the fist input, and reports + // an error when we try to add any other tensors with not matching + // first dim. + // + // This restriction does not apply in explicit batch mode: the tensors + // can have different first dim. The converter still expects that only + // the first arg is a tensor. TODO(tfeher) Check if one can relax this + // restriction. + Status expected_status = + (i != 0 && trt_mode == TrtTestMode::kImplicitBatch) + ? 
errors::InvalidArgument( + StrCat("Batch size doesn't match for tensor ", + node_input[i].name, + ": Provided batch size does not match " + "converter batch size: 3 vs 2")) + : Status::OK(); + std::vector partial_input_shape; + if (i == 0 && trt_mode == TrtTestMode::kDynamicShape && + !p.keep_channel_unknown) { + // keep channel dim static (known) + partial_input_shape.resize(4, -1); + partial_input_shape[1] = node_input[i].dims[1]; + } + AddTestTensor(node_input[i].name, node_input[i].dims, tf_type, + node_input[i].val, partial_input_shape, + expected_status); + + } else { + AddTestWeights(node_input[i].name, node_input[i].dims, + node_input[i].val, tf_type); + } + } + TestOpConverter("my_batchnorm", node_def, node_input[0].dims, + p.conversion_status, Status::OK(), + ArrayFloatNear(expected_output)); + } + } +} // namespace convert + TEST_P(OpConverterTest1, ConvertTranspose) { // Get the NodeDef for Transpose. Scope s = Scope::NewRootScope(); @@ -2169,6 +2314,52 @@ TEST_F(OpConverterTest, ConvertReshape) { } } +TEST_P(OpConverterTest1, ConvertShape) { + // Get the NodeDef for Shape op. + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto shape = ops::Shape(s.WithOpName("my_shape"), input); + const NodeDef& node_def = shape.operation.node()->def(); + + Status conversion_status = + (trt_mode == TrtTestMode::kImplicitBatch) + ? errors::Unimplemented( + "Shape is only supported for explicit batch mode.") + : Status::OK(); + std::vector test_params = { + TestParamBase{{1, 2, 3}, {}, {3}, {}, conversion_status}, + // Add input as weight (we use non empty param ({1}) to trigger this). + TestParamBase{{1, 2, 3}, {}, {3}, {1}, conversion_status}, + }; + + auto input_is_weight = [](const TestParamBase p) { return !p.param.empty(); }; + for (auto p : test_params) { + SCOPED_TRACE(p); + Reset(); + // The number of elements of the input tensor. We leave it 0 in case we do + // not need to add an input tensor. This happens in explicit batch mode: the + // shape is known at conversion time and therefore the shape is added to the + // network as a constant layer. In this case the single node network that + // we use for the unit test have no actual input tensor when it is converted + // to a TensorRT network. + int n_elements = 0; + if (input_is_weight(p) || trt_mode != TrtTestMode::kExplicitBatch) { + // Calculate the number of elements for adding input data. + n_elements = std::accumulate(p.input_dims.begin(), p.input_dims.end(), 1, + std::multiplies()); + } + std::vector input_val(n_elements, 1); + if (!input_is_weight(p)) { + AddTestTensor("input", p.input_dims, input_val); + } else { + AddTestWeights("input", p.input_dims, input_val, tf_type); + } + TestOpConverter("my_shape", node_def, p.expected_output_dims, p.status, + p.runtime_status, ElementsAreArray(p.input_dims), + {DT_INT32}); + } +} + // Helper function for testing MatMul and BatchMatMul // get_matmul corresponds to the function used to generate the node. It should // accept (DataType, transpose_a, transpose_b) as parameters. @@ -4039,72 +4230,81 @@ TEST_P(OpConverterTest1, ConvertConv2D) { // Ok. 
std::vector ok_params = { - // Basic - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 2}, - /*expected_output=*/{1, 1, 0, 1}}, - // SAME padding (Asymmetric) - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"SAME", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 3}, - /*expected_output=*/{1, 1, -2, 0, 1, -4}}, - // SAME padding (Symmetric) - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 3, 1, 1}, - /*filter=*/{-1, 0, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"SAME", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 3}, - /*expected_output=*/{1, 2, -1, 3, 1, -3}}, - // NHWC - TestParams{/*input_dims=*/{1, 2, 3, 1}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NHWC", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 2, 2, 1}, - /*expected_output=*/{1, 1, 0, 1}}, - // Dilated - TestParams{/*input_dims=*/{1, 1, 2, 3}, - /*input=*/{0, 1, 2, 3, 3, 4}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 2}, - /*expected_output_dims=*/{1, 1, 2, 1}, - /*expected_output=*/{2, 1}}, - // Strided - TestParams{/*input_dims=*/{1, 1, 2, 4}, - /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, - /*filter_dims=*/{1, 2, 1, 1}, - /*filter=*/{-1, 1}, - /*strides=*/{1, 1, 1, 2}, - /*padding=*/"VALID", - /*data_format=*/"NCHW", - /*dilations=*/{1, 1, 1, 1}, - /*expected_output_dims=*/{1, 1, 2, 2}, - /*expected_output=*/{1, 0, 1, 3}}, +// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. +#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + // Basic + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 2}, + /*expected_output=*/{1, 1, 0, 1}}, +#endif +// TODO(b/162448349): Enable the test parameters for TRT 7.1.3.x. +#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + // SAME padding (Asymmetric) + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 1, -2, 0, 1, -4}}, + // SAME padding (Symmetric) + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 3, 1, 1}, + /*filter=*/{-1, 0, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{1, 2, -1, 3, 1, -3}}, +#endif +// TODO(b/162447069): Enable the test parameters for TRT 7.1.3.x. 
+#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + // NHWC + TestParams{/*input_dims=*/{1, 2, 3, 1}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NHWC", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 2, 2, 1}, + /*expected_output=*/{1, 1, 0, 1}}, + // Dilated + TestParams{/*input_dims=*/{1, 1, 2, 3}, + /*input=*/{0, 1, 2, 3, 3, 4}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 2}, + /*expected_output_dims=*/{1, 1, 2, 1}, + /*expected_output=*/{2, 1}}, + // Strided + TestParams{/*input_dims=*/{1, 1, 2, 4}, + /*input=*/{0, 1, 2, 2, 3, 4, 4, 7}, + /*filter_dims=*/{1, 2, 1, 1}, + /*filter=*/{-1, 1}, + /*strides=*/{1, 1, 1, 2}, + /*padding=*/"VALID", + /*data_format=*/"NCHW", + /*dilations=*/{1, 1, 1, 1}, + /*expected_output_dims=*/{1, 1, 2, 2}, + /*expected_output=*/{1, 0, 1, 3}}, +#endif }; for (int i = 0; i < ok_params.size(); i++) { @@ -4589,41 +4789,72 @@ TEST_F(OpConverterTest, ConvertConv3D) { ElementsAreArray(ok_params[i].expected_output)); } } +#endif -TEST_F(OpConverterTest, ConvertPool3D) { - // Get nodedef for MaxPool3D and AvgPool3D layers. - auto get_pool3d_nodedef = [](std::vector ksize = {1, 1, 1, 1, 1}, - std::vector strides = {1, 1, 1, 1, 1}, - string padding = "SAME", - string data_format = "NCDHW", - const bool is_max_pooling = true) -> NodeDef { - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - +template +NodeDef CreatePoolOp(DataType tf_type, std::vector ksize, + std::vector strides, string padding, + string data_format) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + typename T::Attrs attrs; + attrs.data_format_ = data_format; + return T(s.WithOpName("my_pool"), input, ksize, strides, padding, attrs) + .operation.node() + ->def(); +} +TEST_P(OpConverterTest1, ConvertPool) { + // Get nodedef for MaxPool and AvgPool layers (2D or 3D). + auto get_pool_nodedef = + [](DataType tf_type, int nDim, std::vector ksize = {}, + std::vector strides = {}, string padding = "SAME", + string data_format = "", const bool is_max_pooling = true) -> NodeDef { + if (ksize.empty()) { + ksize = nDim == 2 ? std::vector{1, 1, 1, 1} + : std::vector{1, 1, 1, 1, 1}; + } + if (strides.empty()) { + strides = nDim == 2 ? std::vector{1, 1, 1, 1} + : std::vector{1, 1, 1, 1, 1}; + } + if (data_format == "") { + data_format = nDim == 2 ? 
"NCHW" : "NCDHW"; + } if (is_max_pooling) { - ops::MaxPool3D::Attrs attrs = - ops::MaxPool3D::Attrs().DataFormat(data_format); - auto pool3d = ops::MaxPool3D(s.WithOpName("my_maxpool3d"), input, ksize, - strides, padding, attrs); - return pool3d.operation.node()->def(); + if (nDim == 3) { + return CreatePoolOp(tf_type, ksize, strides, padding, + data_format); + } else { + return CreatePoolOp(tf_type, ksize, strides, padding, + data_format); + } } else { - ops::AvgPool3D::Attrs attrs = - ops::AvgPool3D::Attrs().DataFormat(data_format); - auto pool3d = ops::AvgPool3D(s.WithOpName("my_avgpool3d"), input, ksize, - strides, padding, attrs); - return pool3d.operation.node()->def(); + if (nDim == 3) { + return CreatePoolOp(tf_type, ksize, strides, padding, + data_format); + } else { + return CreatePoolOp(tf_type, ksize, strides, padding, + data_format); + } } }; - { +#if IS_TRT_VERSION_GE(6, 0, 0, 0) + std::vector test_nDims{2, 3}; +#else + std::vector test_nDims{2}; +#endif + + for (int nDim : test_nDims) { // Input is weights, should fail. Reset(); - NodeDef node_def = get_pool3d_nodedef(); + NodeDef node_def = get_pool_nodedef(tf_type, nDim); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); - RunValidationAndConversion( - node_def, error::UNIMPLEMENTED, - "The input \"input\" for MaxPool3D must be a tensor, at my_maxpool3d"); + AddTestWeights("input", {1, 1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); + RunValidationAndConversion(node_def, error::UNIMPLEMENTED, + StrCat("The input \"input\" for ", node_def.op(), + " must be a tensor, at my_pool") + .c_str()); } struct TestParams { @@ -4633,150 +4864,110 @@ TEST_F(OpConverterTest, ConvertPool3D) { std::vector strides; string padding; string data_format; - bool is_max_pooling; std::vector expected_output_dims; - std::vector expected_output; + // The expected outputs for the following operations: MaxPool2D, AvgPool2D, + // MaxPool3D, AvgPool3D + std::vector> expected_outputs; }; - // Start here - const std::vector common_array{-4, 2, 15, 3, 6, -3, 22, 1, 88, + // We use common_input as the input to test both 2D and 3D pooling operations, + // to simplify TestParams. For 2D operations, only the first 1/3 of the values + // are used. + const std::vector common_input{-4, 2, 15, 3, 6, -3, 22, 1, 88, 56, 36, 1, 1, 105, 1, 16, -28, 1, 42, 9, 3, 1, 7, 1, 11, 61, 5}; + // The output of 2D ops for the case where the op is equivalent to the + // identity op. 
+ const std::vector common_2d_output{-4, 2, 15, 3, 6, -3, 22, 1, 88}; std::vector ok_params = { // Basic - just 1x1 max pooling - input = output - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/true, - /*expected_output_dims=*/{1, 3, 3, 3}, - /*expected_output=*/common_array}, - // Basic - just 1x1 avg pooling - input = output - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/false, - /*expected_output_dims=*/{1, 3, 3, 3}, - /*expected_output=*/common_array}, + TestParams{ + /*input_dims=*/{1, 1, 3, 3, 3}, + /*input=*/common_input, + /*ksize=*/{1, 1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"VALID", + /*data_format=*/"NCDHW", + /*expected_output_dims=*/{1, 1, 3, 3, 3}, + /*expected_outputs=*/ + {common_2d_output, common_2d_output, common_input, common_input}}, // Basic - just 1x1 max pooling - input = output, SAME padding - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"SAME", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/true, - /*expected_output_dims=*/{1, 3, 3, 3}, - /*expected_output=*/common_array}, - // Basic - just 1x1 avg pooling - input = output, SAME padding - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/false, - /*expected_output_dims=*/{1, 3, 3, 3}, - /*expected_output=*/common_array}, - // 3x3 max pooling - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, + TestParams{ + /*input_dims=*/{1, 1, 3, 3, 3}, + /*input=*/common_input, + /*ksize=*/{1, 1, 1, 1, 1}, + /*strides=*/{1, 1, 1, 1, 1}, + /*padding=*/"SAME", + /*data_format=*/"NCDHW", + /*expected_output_dims=*/{1, 1, 3, 3, 3}, + /*expected_outputs=*/ + {common_2d_output, common_2d_output, common_input, common_input}}, + // 3x3 pooling NCDHW + TestParams{/*input_dims=*/{1, 1, 3, 3, 3}, + /*input=*/common_input, /*ksize=*/{1, 1, 3, 3, 3}, /*strides=*/{1, 1, 1, 1, 1}, /*padding=*/"VALID", /*data_format=*/"NCDHW", - /*is_max_pooling=*/true, - /*expected_output_dims=*/{1, 1, 1, 1}, - /*expected_output=*/{105}}, - // 3x3 avg pooling - TestParams{/*input_dims=*/{1, 3, 3, 3}, - /*input=*/common_array, - /*ksize=*/{1, 1, 3, 3, 3}, + /*expected_output_dims=*/{1, 1, 1, 1, 1}, + /*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}}, + // 3x3 pooling, NDHWC + TestParams{/*input_dims=*/{1, 3, 3, 3, 1}, + /*input=*/common_input, + /*ksize=*/{1, 3, 3, 3, 1}, /*strides=*/{1, 1, 1, 1, 1}, /*padding=*/"VALID", + /*data_format=*/"NDHWC", + /*expected_output_dims=*/{1, 1, 1, 1, 1}, + /*expected_outputs=*/{{88}, {14.444445}, {105}, {17}}}, + // Strided + TestParams{/*input_dims=*/{1, 1, 3, 3, 3}, + /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8}, + /*ksize=*/{1, 1, 1, 1, 1}, + /*strides=*/{1, 1, 2, 2, 2}, + /*padding=*/"VALID", /*data_format=*/"NCDHW", - /*is_max_pooling=*/false, - /*expected_output_dims=*/{1, 1, 1, 1}, - /*expected_output=*/{17}}, - // 3x3 max pooling, NDHWC - TestParams{/*input_dims=*/{3, 3, 3, 1}, - /*input=*/common_array, - /*ksize=*/{1, 3, 3, 3, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"VALID", - 
/*data_format=*/"NDHWC", - /*is_max_pooling=*/true, - /*expected_output_dims=*/{1, 1, 1, 1}, - /*expected_output=*/{105}}, - // 3x3 avg pooling, NDHWC - TestParams{/*input_dims=*/{3, 3, 3, 1}, - /*input=*/common_array, - /*ksize=*/{1, 3, 3, 3, 1}, - /*strides=*/{1, 1, 1, 1, 1}, - /*padding=*/"VALID", - /*data_format=*/"NDHWC", - /*is_max_pooling=*/false, - /*expected_output_dims=*/{1, 1, 1, 1}, - /*expected_output=*/{17}}, - // Strided max - TestParams{ - /*input_dims=*/{1, 3, 3, 3}, - /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8}, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 2, 2, 2}, - /*padding=*/"VALID", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/true, - /*expected_output_dims=*/{1, 2, 2, 2}, - /*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8} // Should only pick up - // the corners - }, - // Strided avg - TestParams{ - /*input_dims=*/{1, 3, 3, 3}, - /*input=*/{1, 0, 2, 0, 0, 0, 3, 0, 4, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 5, 0, 6, 0, 0, 0, 7, 0, 8}, - /*ksize=*/{1, 1, 1, 1, 1}, - /*strides=*/{1, 1, 2, 2, 2}, - /*padding=*/"VALID", - /*data_format=*/"NCDHW", - /*is_max_pooling=*/false, - /*expected_output_dims=*/{1, 2, 2, 2}, - /*expected_output=*/{1, 2, 3, 4, 5, 6, 7, 8} // Should only pick up - // the corners - }}; + /*expected_output_dims=*/{1, 1, 2, 2, 2}, + /*expected_outputs=*/ + {{1, 2, 3, 4}, // Should only pick up the corners + {1, 2, 3, 4}, + {1, 2, 3, 4, 5, 6, 7, 8}, + {1, 2, 3, 4, 5, 6, 7, 8}}}, + }; - for (int i = 0; i < ok_params.size(); i++) { - Reset(); - NodeDef node_def = get_pool3d_nodedef( - ok_params[i].ksize, ok_params[i].strides, ok_params[i].padding, - ok_params[i].data_format, ok_params[i].is_max_pooling); - AddTestTensor("input", ok_params[i].input_dims); - RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - string expected_node_name = - ok_params[i].is_max_pooling ? "my_maxpool3d" : "my_avgpool3d"; - TF_EXPECT_OK(GetTensorOrWeights(expected_node_name, &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims, - output.tensor()->getDimensions()); - - const DataVec input_data{{"input", AsTensor(ok_params[i].input)}}; - DataVec output_data{ - {expected_node_name, - ConstructTensor(ok_params[i].expected_output.size())}}; - TF_EXPECT_OK(BuildAndRun(input_data, &output_data)); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAreArray(ok_params[i].expected_output)); + for (auto p : ok_params) { + int test_counter = 0; + for (int nDim : test_nDims) { + auto input = p.input; + auto input_dims = p.input_dims; + auto ksize = p.ksize; + auto strides = p.strides; + auto expected_output_dims = p.expected_output_dims; + std::string data_format = p.data_format; + if (nDim == 2) { + input.resize(9); + data_format = p.data_format == "NDHWC" ? 
"NHWC" : "NCHW"; + // Remove one of the spatial dimensions + input_dims.erase(input_dims.begin() + 2); + ksize.erase(ksize.begin() + 2); + strides.erase(strides.begin() + 2); + expected_output_dims.erase(expected_output_dims.begin() + 2); + } + for (bool is_max_pooling : {true, false}) { + Reset(); + NodeDef node_def = + get_pool_nodedef(tf_type, nDim, ksize, strides, p.padding, + data_format, is_max_pooling); + AddTestTensor("input", input_dims, input); + TestOpConverter("my_pool", node_def, expected_output_dims, Status::OK(), + Status::OK(), + ElementsAreArray(p.expected_outputs.at(test_counter))); + test_counter++; + } + } } } -#endif // IS_TRT_VERSION_GE(6, 0, 0, 0) TEST_F(OpConverterTest, ConvertTopK) { // TODO(tmorris): This test isn't setting the input dtype properly. TopK with @@ -5052,6 +5243,148 @@ TEST_P(OpConverterTest3, ConvertGather) { } } +template +NodeDef CreateReduceOp(DataType tf_type, bool keep_dims) { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto axis = ops::Placeholder(s.WithOpName("axis"), DT_INT32); + typename OpType::Attrs op_attrs; + op_attrs.keep_dims_ = keep_dims; + auto op = OpType(s.WithOpName("my_reduce"), input, axis, op_attrs); + return op.operation.node()->def(); +} + +// Applies reduction op on sub-sequences of input +// output[i] = reduce(input[m * i : m * (i +1)]) +std::vector CalcReduce(string op_name, std::vector input, int m, + float (*op)(float, float), float init) { + std::vector output(input.size() / m); + for (int i = 0; i < output.size(); i++) { + auto begin = input.begin() + i * m; + auto end = input.begin() + (i + 1) * m; + output[i] = std::accumulate(begin, end, init, op); + if (op_name == "Mean") { + output[i] /= m; + } + } + return output; +} +TEST_P(OpConverterTest1, ConvertReduce) { + { + // Input is weights, should fail. + Reset(); + const NodeDef node_def = CreateReduceOp(tf_type, false); + AddTestWeights("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + AddTestWeights("axis", {1}, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"input\" for Sum must be a tensor, at my_reduce"); + } + { + // Axis is weights, should fail. + Reset(); + const NodeDef node_def = CreateReduceOp(tf_type, false); + AddTestTensor("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2}); + AddTestTensor("axis", {1}, DT_INT32, {1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "The input \"axis\" for Sum must be a constant, at my_reduce"); + } + using OpFunc = std::function; + using ValFunc = float (*)(float, float); + struct ReduceTestDescriptor { + string name; + OpFunc get_node; + ValFunc val_func; + float init_val; + }; + std::vector op_test_info{ + {"Sum", CreateReduceOp, [](float x, float y) { return x + y; }, + 0}, + {"Prod", CreateReduceOp, + [](float x, float y) { return x * y; }, 1}, + {"Mean", CreateReduceOp, + [](float x, float y) { return x + y; }, 0}, + {"Min", CreateReduceOp, + [](float x, float y) { return y < x ? y : x; }, 1000}, + {"Max", CreateReduceOp, + [](float x, float y) { return x < y ? 
y : x; }, -1000}}; + + std::vector input_values{1, 2, 3, 4, 5, 6}; + struct TestParams { + std::vector input_dims; + std::vector input_values; + // Helper array contains the same elements as input but permuted in a way + // that the reduction can be calculated over contiguous elements using + // CalcReduce + std::vector helper_array; + std::vector axis; + int stride; // product of input_dims along axis + Status conversion_status; + }; + std::vector params{ + // Out of range tests + TestParams{{2, 3, 1}, input_values, input_values, {3}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {-4}, 3}, + // Ok tests + TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {0}, 2}, + TestParams{{2, 3, 1}, input_values, input_values, {1}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {2}, 1}, + TestParams{{2, 3, 1}, input_values, input_values, {0, 1}, 6}, + // Ok tests with negative axis values + TestParams{{2, 3, 1}, input_values, {1, 4, 2, 5, 3, 6}, {-3}, 2}, + TestParams{{2, 3, 1}, input_values, input_values, {-2}, 3}, + TestParams{{2, 3, 1}, input_values, input_values, {-1}, 1}, + TestParams{{2, 3, 1}, input_values, input_values, {-3, 1}, 6}, + }; + + for (bool keep_dims : {false, true}) { + for (auto& op : op_test_info) { + for (auto p : params) { + SCOPED_TRACE(StrCat(op.name, keep_dims ? "keep_dims" : "")); + Reset(); + NodeDef node_def = op.get_node(tf_type, keep_dims); + + AddTestTensor("input", p.input_dims, p.input_values); + AddTestWeights("axis", {static_cast(p.axis.size())}, + p.axis); + std::vector expected_output_dims(p.input_dims); + + // Set expected output dim and conversion error messages + for (int ax : p.axis) { + int rank = p.input_dims.size(); + if (ax >= rank || ax < -rank) { + p.conversion_status = + errors::InvalidArgument("Axis value of ", ax, + " is out of bounds, must be in " + "range [", + -rank, ", ", rank, "), at my_reduce"); + } else { + int ax_positive = ax >= 0 ? ax : ax + rank; + // Zero marks elements that we will remove later. + expected_output_dims[ax_positive] = keep_dims ? 1 : 0; + if (trt_mode == TrtTestMode::kImplicitBatch && + (ax == 0 || ax == -rank)) { + p.conversion_status = errors::Unimplemented( + "TensorRT does not allow manipulation of the batch " + "dimension, at my_reduce"); + } + } + } + expected_output_dims.erase(std::remove(expected_output_dims.begin(), + expected_output_dims.end(), 0), + expected_output_dims.end()); + VLOG(2) << "out dims " << expected_output_dims; + std::vector expected_values = CalcReduce( + op.name, p.helper_array, p.stride, op.val_func, op.init_val); + TestOpConverter("my_reduce", node_def, expected_output_dims, + p.conversion_status, Status::OK(), + ArrayFloatNear(expected_values)); + } + } + } +} + NodeDef CreateCastOp(DataType tf_type) { Scope s = Scope::NewRootScope(); auto input = ops::Placeholder(s.WithOpName("input"), DT_HALF); @@ -6442,19 +6775,22 @@ void TestConvertResize(OpConverterTest* test) { typedef typename EnumToDataType::Type CType; std::vector> params{ - { - /*input_dims=*/{1, 2, 1}, // H, W, C - /*output_resize_dims=*/{2, 3}, // H_out, W_out - /*input_values=*/CastTestVector({2.0f, -1.0f}), - /*align_corners=*/false, - /*expected_output_dims=*/{2, 3, 1}, // H, W, C - /*expected_nearest_output_values=*/ - CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), - /*expected_bilinear_output_values=*/ - CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), - }, - { - /*input_dims=*/{1, 2, 1}, // H, W, C +// TODO(b/162442839): Enable the test parameters for TRT 7.1.3.x. 
+#if !IS_TRT_VERSION_GE(7, 1, 3, 0) + { + /*input_dims=*/{1, 2, 1}, // H, W, C + /*output_resize_dims=*/{2, 3}, // H_out, W_out + /*input_values=*/CastTestVector({2.0f, -1.0f}), + /*align_corners=*/false, + /*expected_output_dims=*/{2, 3, 1}, // H, W, C + /*expected_nearest_output_values=*/ + CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), + /*expected_bilinear_output_values=*/ + CastTestVector({2.0f, 0.f, -1.0f, 2.0f, 0.f, -1.0f}), + }, +#endif + { + /*input_dims=*/{1, 2, 1}, // H, W, C /*output_resize_dims=*/{2, 3}, // H_out, W_out /*input_values=*/CastTestVector({2.0f, -1.0f}), /*align_corners=*/true, @@ -6463,7 +6799,8 @@ void TestConvertResize(OpConverterTest* test) { CastTestVector({2.0f, 2.0f, -1.0f, 2.0f, 2.0f, -1.0f}), /*expected_bilinear_output_values=*/ CastTestVector({2.0f, 0.5f, -1.0f, 2.0f, 0.5f, -1.0f}), - }}; + } + }; for (int i = 0; i < params.size(); ++i) { test->Reset(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 1cf98d135cb..4d6f8fa1b31 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -86,6 +86,7 @@ void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster, string offset2 = StrCat(offset, offset); string offset3 = StrCat(offset2, offset); string offset4 = StrCat(offset2, offset2); + if (cluster) { LOG(INFO) << offset << "type = " << cluster->type(); LOG(INFO) << offset << "num warmup steps = " << cluster->NumWarmupSteps(); @@ -132,7 +133,15 @@ void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster, } } } + + if (cluster->GetDeviceSet()) { + for (const auto dev : cluster->GetDeviceSet()->devices()) { + LOG(INFO) << "Device name= " << dev->name() << "Pased name= " + << DeviceNameUtils::ParsedNameToString(dev->parsed_name()); + } + } } + LOG(INFO) << "item: " << item.id; if (!item.feed.empty()) { LOG(INFO) << offset << "Feeds :"; @@ -171,13 +180,6 @@ void TRTOptimizationPass::PrintDebugInfo(grappler::Cluster* cluster, } else { LOG(INFO) << offset << "No keep ops"; } - for (const auto dev : cluster->GetDeviceSet()->devices()) { - const auto& pname = dev->parsed_name(); - LOG(INFO) << "Device name= " << dev->name() - << " parsedname job= " << pname.job << " id= " << pname.id - << " has_id: " << pname.has_id << " has_job: " << pname.has_job - << "has_type: " << pname.has_type << " type =" << pname.type; - } } Status TRTOptimizationPass::Optimize(grappler::Cluster* cluster, diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 1094555a622..58d1c611463 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -643,8 +643,10 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, } // Release any outputs that are allocated, ExecuteNativeSegment will // re-allocate them and fail if they are currently allocated. + // The Tensor pointer in the returned TensorValue must be explicitly + // deleted. 
for (int i = 0; i < ctx->num_outputs(); i++) { - ctx->release_output(i); + delete ctx->release_output(i).tensor; } ExecuteNativeSegment(ctx, helper); return; diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index e7820ca41fe..1337a733f91 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -711,15 +711,15 @@ Status SegmentGraph(const Graph* tf_graph, std::unordered_set unsupported_ops; int num_unsupported_ops = 0; - // Getting the operations blacklisted for conversion - string tftrt_op_blacklist_str; + // Getting the operations denylisted for conversion + string tftrt_op_denylist_str; TF_CHECK_OK( - ReadStringFromEnvVar("TF_TRT_OP_BLACKLIST", "", &tftrt_op_blacklist_str)); + ReadStringFromEnvVar("TF_TRT_OP_DENYLIST", "", &tftrt_op_denylist_str)); - auto tftrt_op_blacklist = gtl::FlatSet{}; // non-absl ok + auto tftrt_op_denylist = gtl::FlatSet{}; // non-absl ok - for (const auto& x : str_util::Split(tftrt_op_blacklist_str, ",")) { - tftrt_op_blacklist.insert(x); + for (const auto& x : str_util::Split(tftrt_op_denylist_str, ",")) { + tftrt_op_denylist.insert(x); } // Parsing each node of the graph @@ -761,13 +761,13 @@ Status SegmentGraph(const Graph* tf_graph, const Status status = candidate_fn(node->tf_node()); if (!status.ok()) { exclude_node(status.error_message()); - } else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) { + } else if (tftrt_op_denylist.count(node->tf_node()->type_string())) { // WARNING verbosity since the user explicitly requests this behavior. LOG_WARNING_WITH_PREFIX - << "Blacklisted as TF-TRT candidate, " + << "Denylisted as TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " << "(Op name: " << node->name() << ")"; - exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST"); + exclude_node("Denylisted with the env var TF_TRT_OP_DENYLIST"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " @@ -1031,7 +1031,8 @@ Status SegmentGraph(const Graph* tf_graph, }); // Don't use segments whose number of effective nodes is small. - if (num_effective_nodes < options.minimum_segment_size) { + if (num_effective_nodes == 0 || + num_effective_nodes < options.minimum_segment_size) { VLOG(1) << "Segment " << segments->size() << " has only " << num_effective_nodes << " effective nodes, dropping"; continue; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index d4f3a524577..a73877bc3cc 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -74,7 +74,7 @@ void* TRTDeviceAllocator::allocate(uint64_t size, uint64_t alignment, // algorithm uses too much memory. If we don't fail immediately building the // engine can be *very* slow with TensorRT7 when GPU memory is limited. 
AllocationAttributes attributes; - attributes.no_retry_on_failure = true; + attributes.retry_on_failure = false; void* mem = allocator_->AllocateRaw(alignment, total_size, attributes); if (!mem) return nullptr; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc index 70a0a9a7b65..2f31865751f 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -35,14 +36,16 @@ void TrtShapeOptimizationProfile::InitProfiles() { << "for each input (min=opt=max)."; } for (auto& shape_vec : input_shapes_) { - std::vector dimvec; - for (auto& shape : shape_vec) { - dimvec.push_back(TensorShapeToTrtDims(shape, false)); + if (!shape_vec.empty()) { + std::vector dimvec(shape_vec.size()); + absl::c_transform(shape_vec, dimvec.begin(), [](TensorShape shape) { + return TensorShapeToTrtDims(shape, false); + }); + // Set min=opt=max. + OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec}; + profiles_.push_back(std::move(profConfig)); + VLOG(1) << "Created profile " << profiles_.back().DebugString(); } - // We set min=opt=max. - OptimizationProfileConfig profConfig{dimvec, dimvec, dimvec}; - profiles_.push_back(std::move(profConfig)); - VLOG(1) << "Created profile " << profiles_.back().DebugString(); } } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cac72925dfd..1e57c11b2cf 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -27,6 +27,7 @@ package_group( "//tensorflow/compiler/mlir/...", "//tensorflow/compiler/tests/...", "//tensorflow/compiler/tf2xla/...", + "//tensorflow/core/tpu/...", "//tensorflow/python/compiler/...", ], ) @@ -49,6 +50,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":xla_compiler", + ":xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -144,6 +146,7 @@ cc_library( ":tf2xla_proto_cc", ":tf2xla_util", ":xla_compiler", + ":xla_op_registry", "//tensorflow/compiler/aot:aot_only_var_handle_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", @@ -315,14 +318,8 @@ cc_library( srcs = [ "const_analysis.cc", "graph_compiler.cc", - "xla_compilation_device.cc", "xla_compiler.cc", - "xla_context.cc", - "xla_expression.cc", - "xla_helpers.cc", "xla_op_kernel.cc", - "xla_op_registry.cc", - "xla_resource.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", @@ -332,14 +329,10 @@ cc_library( hdrs = [ "const_analysis.h", "graph_compiler.h", - "xla_compilation_device.h", "xla_compiler.h", - "xla_context.h", - "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", - "xla_resource.h", ], visibility = [":friends"], deps = [ @@ -350,10 +343,18 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + ":xla_argument", + ":xla_compilation_device", + ":xla_context", + ":xla_expression", + ":xla_helpers", + ":xla_op_registry", + ":xla_resource", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", + 
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -369,6 +370,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -387,6 +389,172 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_compilation_device", + srcs = [ + "xla_compilation_device.cc", + ], + hdrs = [ + "xla_compilation_device.h", + ], + deps = [ + ":common", + ":frontend_attributes_util", + ":sharding_util", + ":xla_context", + ":xla_helpers", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:session_options", + "//tensorflow/core/common_runtime:core_cpu_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_context", + srcs = [ + "xla_context.cc", + ], + hdrs = [ + "xla_context.h", + ], + deps = [ + ":common", + ":xla_expression", + ":xla_helpers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:core_cpu_internal", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_op_registry", + srcs = [ + "xla_op_registry.cc", + ], + hdrs = [ + "xla_op_registry.h", + ], + visibility = [":friends"], + deps = [ + ":common", + ":xla_context", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/common_runtime:core_cpu_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_expression", + srcs = [ + "xla_expression.cc", + ], + hdrs = [ + "xla_expression.h", + ], + deps = [ + ":common", + ":xla_resource", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_resource", + srcs = [ + "xla_resource.cc", + ], + hdrs = [ + "xla_resource.h", + ], + deps = [ + ":common", + ":sharding_util", + ":xla_helpers", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_helpers", + srcs = [ + "xla_helpers.cc", + ], + hdrs = [ + "xla_helpers.h", + ], + visibility = [":friends"], + deps = [ + ":common", + ":host_compute_metadata_proto_cc", + "//tensorflow/compiler/tf2xla/lib:util", + 
"//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_argument", + srcs = [ + "xla_argument.cc", + ], + hdrs = [ + "xla_argument.h", + ], + deps = [ + ":host_compute_metadata_proto_cc", + ":xla_resource", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + cc_library( name = "common", srcs = [ @@ -563,6 +731,8 @@ tf_cc_test( ":common", ":side_effect_util", ":xla_compiler", + ":xla_expression", + ":xla_resource", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:functional_ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 1da34266460..694aa342aac 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -74,8 +74,7 @@ Status CondConstInputIndices( *(fbody->graph), &compile_time_const_arg_indices, /*compile_time_const_nodes=*/nullptr, flib_runtime)); } - for (int i = 0, iter_limit = compile_time_const_arg_indices.size(); - i < iter_limit; i++) { + for (int i = 0, end = compile_time_const_arg_indices.size(); i < end; i++) { if (compile_time_const_arg_indices[i]) { // The 0th input is the pred or branch index, which is not passed to the // branches. So the i'th input of a branch function corresponds to the @@ -141,7 +140,7 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, GetFunctionBody(flib_runtime, node, "else_branch", &felse)); return CondConstInputIndices({fthen, felse}, const_input_idxs, flib_runtime); - } else if (node.op() == "Case") { + } else if (node.op() == "Case" || node.op() == "StatelessCase") { std::vector branch_bodies; TF_RETURN_IF_ERROR( GetFunctionBodies(flib_runtime, node, "branches", &branch_bodies)); diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 459b2814c0d..54abccb4cfc 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -224,8 +224,8 @@ string DebugString(const CondArgNodes& nodes) { } StateMap::CondId StateMap::LookupCondId(const Node* node) const { - if (node->id() < node_to_condid_map_.size()) - return node_to_condid_map_[node->id()]; + const int64 map_size = node_to_condid_map_.size(); + if (node->id() < map_size) return node_to_condid_map_[node->id()]; return added_node_condid_mapping_.at(node->id()); } @@ -235,15 +235,16 @@ StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { } void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { - if (node->id() < node_to_condid_map_.size()) + const int64 map_size = node_to_condid_map_.size(); + if (node->id() < map_size) node_to_condid_map_[node->id()] = id; else added_node_condid_mapping_[node->id()] = id; } StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { - if (node->id() < node_to_ancestorid_map_.size()) - return node_to_ancestorid_map_[node->id()]; + const int64 map_size = node_to_ancestorid_map_.size(); + if (node->id() < map_size) 
return node_to_ancestorid_map_[node->id()]; return added_node_ancestorid_mapping_.at(node->id()); } @@ -254,7 +255,8 @@ StateMap::AncestorId StateMap::GetAncestorId( } void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { - if (node->id() < node_to_ancestorid_map_.size()) + const int64 map_size = node_to_ancestorid_map_.size(); + if (node->id() < map_size) node_to_ancestorid_map_[node->id()] = id; else added_node_ancestorid_mapping_[node->id()] = id; diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index cea4973f42b..dce5efe5557 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -130,7 +130,7 @@ Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame, std::vector squash_src_outputs(graph.num_node_ids(), false); // Build one _Arg node for each Enter node. - for (int i = 0; i < frame->args.size(); ++i) { + for (int i = 0, end = frame->args.size(); i < end; ++i) { const WhileLoopArg& arg = frame->args[i]; TF_ASSIGN_OR_RETURN(Node * arg_node, @@ -170,7 +170,7 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, std::vector next_iterations; next_iterations.reserve(frame->args.size()); arg_types->reserve(frame->args.size()); - for (int i = 0; i < frame->args.size(); ++i) { + for (int i = 0, end = frame->args.size(); i < end; ++i) { const WhileLoopArg& arg = frame->args[i]; DataType dtype = arg.enter->input_type(0); @@ -235,7 +235,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, } else { std::vector edges(arg.enter->out_edges().begin(), arg.enter->out_edges().end()); - for (int i = 0; i < edges.size(); ++i) { + for (int i = 0, end = edges.size(); i < end; ++i) { if (edges[i]->IsControlEdge() && edges[i]->dst()->IsSink()) { continue; } @@ -447,7 +447,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, } } std::vector inputs; - for (int i = 0; i < frame->args.size(); ++i) { + for (int i = 0, end = frame->args.size(); i < end; ++i) { const WhileLoopArg& arg = frame->args[i]; const Edge* in_edge; TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); @@ -463,7 +463,7 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, TF_ASSIGN_OR_RETURN(Node * while_node, AddNodeDefToGraph(while_def, graph)); // Copies edges to the Enter nodes and from the Exit nodes onto the While. 
- for (int i = 0; i < frame->args.size(); ++i) { + for (int i = 0, end = frame->args.size(); i < end; ++i) { const WhileLoopArg& arg = frame->args[i]; const Edge* in_edge; TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge)); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 5f6dcad5538..30a7e94775b 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -65,7 +65,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, /*compile_time_const_nodes=*/nullptr, ctx->function_library())); args->resize(expressions.size()); - for (int i = 0, iter_limit = args->size(); i < iter_limit; ++i) { + for (int i = 0, end = args->size(); i < end; ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); @@ -269,7 +269,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RET_CHECK(arguments.size() == expressions.size()); std::vector handles; - for (int64 i = 0, iter_limit = expressions.size(); i < iter_limit; ++i) { + for (int64 i = 0, end = expressions.size(); i < end; ++i) { if (arguments[i].kind == XlaCompiler::Argument::kConstant) { continue; } @@ -313,8 +313,7 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, } } - for (int64 i = 0, iter_limit = result.resource_updates.size(); i < iter_limit; - i++) { + for (int64 i = 0, end = result.resource_updates.size(); i < end; i++) { if (result.resource_updates[i].modified) { XlaResource* resource = expressions[result.resource_updates[i].input_index]->resource(); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index e072225566d..26051c98cb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -145,7 +145,12 @@ tf_kernel_library( "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/lib:random", @@ -223,6 +228,8 @@ cc_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -230,7 +237,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:framework", - "//tensorflow/core:framework_bounds_check", + "//tensorflow/core/framework:bounds_check", "//tensorflow/core/kernels:conv_grad_shape_utils", "@com_google_absl//absl/types:span", ], @@ -276,6 +283,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", @@ -296,6 +305,8 @@ 
tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", @@ -314,6 +325,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", @@ -333,6 +346,7 @@ tf_kernel_library( ], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index fbd54f1ef39..7a3d87c101c 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -160,17 +160,15 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler* compiler = ctx->compiler(); std::vector branch_results(num_branches); - std::vector branch_results_p(num_branches); for (int j = 0; j < num_branches; ++j) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, branches[j], arguments, &branch_results[j])); - branch_results_p[j] = &branch_results[j]; } bool has_tensor_array_gradients = false; - for (XlaCompiler::CompilationResult* result : branch_results_p) { - for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + for (XlaCompiler::CompilationResult& result : branch_results) { + for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index + 1, &resource)); @@ -373,5 +371,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(), XlaCaseOp); +REGISTER_XLA_OP(Name("StatelessCase").AllowResourceTypes().AllowVariantTypes(), + XlaCaseOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 09c97de13eb..d0f24b5f561 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -186,9 +186,11 @@ class ConcatOffsetOp : public XlaOpKernel { const int32 inp0_element = inp0_dims[j]; const int32 inp_element = inp_dims[j]; OP_REQUIRES(ctx, inp0_element == inp_element, - errors::InvalidArgument("input[", i, ",", j, - "] mismatch: ", inp0_element, - " vs. ", inp_element)); + errors::InvalidArgument( + "All dimensions except ", axis, " must match. 
Input ", + i, " has shape [", absl::StrJoin(inp_dims, " "), + "] and doesn't match input 0 with shape [", + absl::StrJoin(inp0_dims, " "), "].")); out_vec(j) = 0; } } diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index e0bc2ba5052..d29644dd0de 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -44,7 +44,7 @@ namespace tensorflow { namespace { // Returns the expanded size of a filter used for depthwise convolution. -// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N]. xla::Shape GroupedFilterShapeForDepthwiseConvolution( const xla::Shape& filter_shape) { int64 input_feature_dim = filter_shape.dimensions_size() - 2; @@ -52,7 +52,7 @@ xla::Shape GroupedFilterShapeForDepthwiseConvolution( int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); int64 input_feature = filter_shape.dimensions(input_feature_dim); - // Create a [H, W, ..., 1, N*M] reshape of the filter. + // Create a [H, W, ..., 1, M*N] reshape of the filter. xla::Shape grouped_filter_shape = filter_shape; grouped_filter_shape.set_dimensions(input_feature_dim, 1); grouped_filter_shape.set_dimensions(output_feature_dim, @@ -203,6 +203,10 @@ xla::StatusOr ConvOpAttrs::Create(int num_spatial_dims, return errors::InvalidArgument("Invalid data format: ", data_format); } + TF_RETURN_IF_ERROR(CheckValidPadding(attrs.padding, attrs.explicit_paddings, + /*num_dims=*/num_spatial_dims + 2, + attrs.data_format)); + return attrs; } diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 2a059f78526..3a88fcf4879 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -47,6 +47,122 @@ XlaIfOp::XlaIfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } } +// Populates tensor array gradients for compiled branches, returns whether the +// set of found tensor array gradients is non-empty. +static xla::StatusOr PopulateTensorArrayGradients( + XlaOpKernelContext* ctx, xla::XlaBuilder* b, + absl::Span arguments, + XlaCompiler::CompilationResult* then_result, + XlaCompiler::CompilationResult* else_result) { + bool has_tensor_array_gradients = false; + for (XlaCompiler::CompilationResult* result : {then_result, else_result}) { + for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { + XlaResource* resource; + TF_RETURN_IF_ERROR( + ctx->GetResourceInput(update.input_index + 1, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; + + // Add any TensorArray gradients touched by the then/else computation to + // the enclosing graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " + << grad_source; + XlaResource* gradient; + TF_RETURN_IF_ERROR(resource->GetOrCreateTensorArrayGradient( + grad_source, b, &gradient)); + } + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients()) { + arg.tensor_array_gradients.insert(gradient.first); + } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; + } + } + return has_tensor_array_gradients; +} + +// Checks that shapes matches on both sides of the conditional. 
+static Status ValidateShapes( + XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& then_result, + const XlaCompiler::CompilationResult& else_result) { + // Check that both branches have identical input shapes. + if (then_result.xla_input_shapes.size() != 1) { + return errors::FailedPrecondition("Expected one input shape"); + } + + xla::Shape then_input_shape = then_result.xla_input_shapes[0]; + if (!then_input_shape.IsTuple()) { + return errors::FailedPrecondition("Expected tuple shape"); + } + + if (else_result.xla_input_shapes.size() != 1) { + return errors::FailedPrecondition("Expected one input shape"); + } + xla::Shape else_input_shape = else_result.xla_input_shapes[0]; + if (!else_input_shape.IsTuple()) { + return errors::FailedPrecondition("Expected tuple shape"); + } + if (!xla::ShapeUtil::Compatible(then_input_shape, else_input_shape)) { + return errors::InvalidArgument( + "Input shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_input_shape), " vs. ", + xla::ShapeUtil::HumanString(else_input_shape)); + } + + // Check that both branches have identical output shapes. + if (!xla::ShapeUtil::Compatible(then_result.xla_output_shape, + else_result.xla_output_shape)) { + return errors::InvalidArgument( + "Output shapes of then and else branches do not match: ", + xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", + xla::ShapeUtil::HumanString(else_result.xla_output_shape)); + } + + // Check that both branches have same TensorList output indices. + for (int output_index = 0; output_index < then_result.outputs.size(); + output_index++) { + bool is_tensor_list_in_then_branch = + then_result.outputs[output_index].is_tensor_list; + bool is_tensor_list_in_else_branch = + else_result.outputs[output_index].is_tensor_list; + if (is_tensor_list_in_then_branch != is_tensor_list_in_else_branch) { + return errors::FailedPrecondition( + "Output #", output_index, " is ", + (is_tensor_list_in_then_branch ? "" : "not"), + " a TensorList in then branch, but is ", + (is_tensor_list_in_else_branch ? "" : "not"), + " a TensorList in else branch"); + } + } + + VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); + VLOG(2) << "Output shape: " + << xla::ShapeUtil::HumanString(then_result.xla_output_shape); + + // We set return_updated_values_for_all_resources=true and we pass the same + // arguments to both computations, so the resource update count must match. + if (then_result.resource_updates.size() != + else_result.resource_updates.size()) { + return errors::FailedPrecondition( + "Different number of resources in then and else branch"); + } + + for (int i = 0; i < then_result.resource_updates.size(); ++i) { + const auto& lhs = then_result.resource_updates[i]; + const auto& rhs = else_result.resource_updates[i]; + bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && + lhs.tensor_array_gradients_accessed == + rhs.tensor_array_gradients_accessed; + if (!equal) { + return errors::FailedPrecondition( + "Mismatch in resource of then and else branch for resource ", i); + } + } + return Status::OK(); +} + // TODO(b/35949885): There is duplication here with the handling of the // while_op. Refactor the common code out/rework. 
void XlaIfOp::Compile(XlaOpKernelContext* ctx) { @@ -137,35 +253,12 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); - bool has_tensor_array_gradients = false; - for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { - for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, - ctx->GetResourceInput(update.input_index + 1, &resource)); - XlaCompiler::Argument& arg = arguments[update.input_index]; - - // Add any TensorArray gradients touched by the then/else computation to - // the enclosing graph. - for (const string& grad_source : update.tensor_array_gradients_accessed) { - VLOG(5) << "TensorArray " << resource->name() << " accessed gradient " - << grad_source; - XlaResource* gradient; - OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( - grad_source, b, &gradient)); - } - // Add all of the TensorArray gradients to the argument. For simplicity, - // we always pass all known gradients. - for (const auto& gradient : resource->tensor_array_gradients()) { - arg.tensor_array_gradients.insert(gradient.first); - } - if (!resource->tensor_array_gradients().empty()) - has_tensor_array_gradients = true; - } - } + xla::StatusOr has_tensor_array_gradients = PopulateTensorArrayGradients( + ctx, b, absl::MakeSpan(arguments), &then_result, &else_result); + OP_REQUIRES_OK(ctx, has_tensor_array_gradients.status()); // Recompile the functions to update the argument shapes for tensor arrays. - if (has_tensor_array_gradients) { + if (*has_tensor_array_gradients) { then_result = {}; OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, arguments, &then_result)); @@ -174,72 +267,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arguments, &else_result)); } - // Check that both branches have identical input shapes. - OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape then_input_shape = then_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, then_input_shape.IsTuple(), - errors::FailedPrecondition("Expected tuple shape")); - OP_REQUIRES(ctx, else_result.xla_input_shapes.size() == 1, - errors::FailedPrecondition("Expected one input shape")); - xla::Shape else_input_shape = else_result.xla_input_shapes[0]; - OP_REQUIRES(ctx, else_input_shape.IsTuple(), - errors::FailedPrecondition("Expected tuple shape")); - OP_REQUIRES(ctx, - xla::ShapeUtil::Compatible(then_input_shape, else_input_shape), - errors::InvalidArgument( - "Input shapes of then and else branches do not match: ", - xla::ShapeUtil::HumanString(then_input_shape), " vs. ", - xla::ShapeUtil::HumanString(else_input_shape))); - - // Check that both branches have identical output shapes. - OP_REQUIRES( - ctx, - xla::ShapeUtil::Compatible(then_result.xla_output_shape, - else_result.xla_output_shape), - errors::InvalidArgument( - "Output shapes of then and else branches do not match: ", - xla::ShapeUtil::HumanString(then_result.xla_output_shape), " vs. ", - xla::ShapeUtil::HumanString(else_result.xla_output_shape))); - - // Check that both branches have same TensorList output indices. 
- for (int output_index = 0; output_index < then_result.outputs.size(); - output_index++) { - bool is_tensor_list_in_then_branch = - then_result.outputs[output_index].is_tensor_list; - bool is_tensor_list_in_else_branch = - else_result.outputs[output_index].is_tensor_list; - OP_REQUIRES( - ctx, is_tensor_list_in_then_branch == is_tensor_list_in_else_branch, - errors::FailedPrecondition("Output #", output_index, " is ", - (is_tensor_list_in_then_branch ? "" : "not"), - " a TensorList in then branch, but is ", - (is_tensor_list_in_else_branch ? "" : "not"), - " a TensorList in else branch")); - } - - VLOG(2) << "Input shape: " << xla::ShapeUtil::HumanString(then_input_shape); - VLOG(2) << "Output shape: " - << xla::ShapeUtil::HumanString(then_result.xla_output_shape); - - // We set return_updated_values_for_all_resources=true and we pass the same - // arguments to both computations, so the resource update count must match. - OP_REQUIRES(ctx, - then_result.resource_updates.size() == - else_result.resource_updates.size(), - errors::FailedPrecondition( - "Different number of resources in then and else branch")); - for (int i = 0; i < then_result.resource_updates.size(); ++i) { - const auto& lhs = then_result.resource_updates[i]; - const auto& rhs = else_result.resource_updates[i]; - bool equal = lhs.input_index == rhs.input_index && lhs.shape == rhs.shape && - lhs.tensor_array_gradients_accessed == - rhs.tensor_array_gradients_accessed; - OP_REQUIRES( - ctx, equal, - errors::FailedPrecondition( - "Mismatch in resource of then and else branch for resource ", i)); - } + OP_REQUIRES_OK(ctx, ValidateShapes(ctx, then_result, else_result)); int num_inputs = then_result.input_mapping.size(); std::vector inputs(num_inputs); @@ -263,22 +291,18 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { } } - auto input_tuple = xla::Tuple(b, inputs); + xla::XlaOp input_tuple = xla::Tuple(b, inputs); xla::XlaOp outputs = xla::Conditional(ctx->Input(0), input_tuple, *then_result.computation, input_tuple, *else_result.computation); + // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { xla::XlaOp output_handle = xla::GetTupleElement(outputs, i); if (VLOG_IS_ON(2)) { - LOG(INFO) << "Setting output " << i; - auto shape_or = b->GetShape(output_handle); - if (shape_or.ok()) { - LOG(INFO) << "Shape for output " << i << ": " - << xla::ShapeUtil::HumanString(shape_or.ValueOrDie()); - } else { - LOG(INFO) << "Shape unknown for output " << i; - } + xla::StatusOr shape = b->GetShape(output_handle); + VLOG(2) << "Setting output " << i << " with shape " + << (shape.ok() ? shape->ToString() : ""); } // We have checked that both branches have same TensorList output indices. if (then_result.outputs[i].is_tensor_list) { @@ -287,6 +311,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { ctx->SetOutput(i, output_handle); } } + if (has_token_input_output_) { // Set token output for this "If" op. 
Token output is the last output of // XLA computation, which comes after all "normal" TF outputs and resource diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 31637d9d8a0..df6d9b475dc 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -71,7 +71,7 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { if (is_gpu_) { output = xla::ArgMinTwoPass(input, index_xla_type, axis); } else { - output = xla::ArgMin(input, index_xla_type, axis); + output = xla::ArgMin(input, index_xla_type, axis, /*stable=*/true); } } else { if (is_gpu_) { diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index 57e961917cc..c8da75157fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -243,8 +243,9 @@ class MatrixDiagOp : public XlaOpKernel { errors::InvalidArgument("MatrixDiag op must have at least one input")); const TensorShape diag_shape = context->InputShape(0); OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape), - errors::InvalidArgument("Expected >= 1 dims, got shape ", - diag_shape.DebugString())); + errors::InvalidArgument( + "diagonal must be at least 1-dim, received shape: ", + diag_shape.DebugString())); const DataType dtype = context->expected_output_dtype(0); const xla::XlaOp zero = XlaHelpers::Zero(context->builder(), dtype); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index bf9a9150ea6..a85ba547179 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -109,27 +109,33 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; - shape_input.clear(); - // Run get input again, this time with dynamic dimension represented as - // "-1" - ctx->set_dynamic_dimension_is_minus_one(true); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &shape_input)); - int dynamic_dimension = -1; - - for (int d = 0; d < num_dims; ++d) { - const int32 size = shape_input[d]; - if (size == -1) { - if (dynamic_dimension == -1) { + if (ctx->InputXlaShape(0)->is_dynamic()) { + std::vector<bool> dynamic_dims; + OP_REQUIRES_OK(ctx, + ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); + for (int d = 0; d < num_dims; ++d) { + const bool dim_is_dynamic = dynamic_dims[d]; + if (dim_is_dynamic) { dynamic_dimension = d; - } else { - if (unknown_index != d) { - dynamic_dimension = d; - } } } - } + // When reshaping from a dynamic dimension, the unknown index is considered + // dynamic. E.g., + // [<=10] + // | + // Reshape + // | + // [2, -1] + // The second dimension is dynamic. + if (dynamic_dimension == -1) { + dynamic_dimension = unknown_index; + } + VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " << xla::VectorString(shape.dim_sizes()) << ", dynamic_dim=" << dynamic_dimension; + } // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference // in XLA to know which output dimension is dynamic.
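For illustration only, a hedged sketch of the hint described in the comment above; the helper name and sizes are invented, and the kernel's actual call follows below. With an operand of shape f32[<=10] reshaped to [2, -1], dimension 1 of the result carries the dynamic size, so index 1 is passed as the inferred dimension.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// `x` has shape f32[<=10]; its true size is only known at run time.
// Passing inferred_dimension=1 marks dimension 1 of the f32[2, 5] result as
// the dynamic one, i.e. the output behaves like f32[2, <=5].
xla::XlaOp ExampleDynamicReshape(xla::XlaOp x) {
  return xla::ReshapeWithInferredDimension(x, /*new_sizes=*/{2, 5},
                                           /*inferred_dimension=*/1);
}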
ctx->SetOutput(0, xla::ReshapeWithInferredDimension( diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc index 1047580264b..da268fe283c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc @@ -30,8 +30,15 @@ class ShardingOp : public XlaOpKernel { ~ShardingOp() override = default; void Compile(XlaOpKernelContext* ctx) override { - xla::XlaOp input = ctx->Input(0); - auto shape_or = ctx->InputXlaShape(0); + xla::XlaOp input; + { + // The builder might create a broadcast from a constant, so we clear + // sharding for the input. + xla::XlaScopedShardingAssignment no_sharding(ctx->builder(), + absl::nullopt); + input = ctx->Input(0); + } + auto shape_or = ctx->builder()->GetShape(input); OP_REQUIRES_OK(ctx, shape_or.status()); ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 7a0e240400b..dbaa84c223d 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -105,6 +105,10 @@ class SplitVOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); const TensorShape index_shape = ctx->InputShape(2); + OP_REQUIRES(ctx, index_shape.num_elements() == 1, + errors::InvalidArgument( + "split_dim_tensor must have exactly one element.")); + int64 split_dim_orig; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig)); int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index aa71e4d4364..0e367e10ec4 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -504,7 +504,9 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp list_part = xla::GetTupleElement(list, 0); xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); - for (int64 i = 0; i < buffer_shape.dimensions_size(); ++i) { + // Propagate dynamic dimensions from buffer to the sliced buffer, except for + // leading dimension (which is always static 1). + for (int64 i = 1; i < buffer_shape.dimensions_size(); ++i) { if (buffer_shape.is_dynamic_dimension(i)) { auto buffer = xla::GetTupleElement(list, 0); auto gds = xla::GetDimensionSize(buffer, i); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 65569576d41..9a4722d149e 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" @@ -110,11 +111,11 @@ REGISTER_XLA_OP(Name("ConjugateTranspose").CompileTimeConstantInput("perm"), // InvertPermutation frequently forms part of the gradient of Transpose. 
// -// inv = InvertPermutationOp(T p) takes a permutation of +// inv = InvertPermutationOp(p) takes a permutation of // integers 0, 1, ..., n - 1 and returns the inverted // permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n). // -// REQUIRES: input is a vector of int32. +// REQUIRES: input is a vector of int32 or int64. // REQUIRES: input is a permutation of 0, 1, ..., n-1. class InvertPermutationOp : public XlaOpKernel { @@ -122,11 +123,32 @@ class InvertPermutationOp : public XlaOpKernel { explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType dtype = ctx->expected_output_dtype(0); + Status status; + switch (dtype) { + case DT_INT32: + InvertPermutation(ctx); + break; + case DT_INT64: + InvertPermutation(ctx); + break; + default: + // This should never happen since we restrict this kernel to only match + // inputs with supported Tensor datatype. + OP_REQUIRES_OK(ctx, errors::InvalidArgument( + "InvertPermutation expects x as either ", + "int32 or int64, not ", DataTypeString(dtype))); + } + } + + template + void InvertPermutation(XlaOpKernelContext* ctx) { OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), - std::numeric_limits::max()), - errors::InvalidArgument("permutation of nonnegative int32s " - "must have <= int32 max elements")); + std::numeric_limits::max()), + errors::InvalidArgument( + "permutation of nonnegative integers must have <= ", + std::numeric_limits::max(), " elements")); auto e = ctx->InputExpression(0); auto tensor_or_status = e.ResolveConstant(ctx->compiler()->client()); @@ -142,7 +164,7 @@ class InvertPermutationOp : public XlaOpKernel { int size = perm.size(); - std::vector output(size); + std::vector output(size); std::fill_n(output.data(), size, -1); for (int i = 0; i < size; ++i) { const int64 d = perm[i]; @@ -153,11 +175,13 @@ class InvertPermutationOp : public XlaOpKernel { output[d] = i; } - ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); + ctx->SetOutput(0, xla::ConstantR1(ctx->builder(), output)); } else { auto indices = ctx->Input(0); - int size = ctx->InputShape(0).num_elements(); - auto iota = xla::Iota(ctx->builder(), xla::S32, size); + T size = ctx->InputShape(0).num_elements(); + auto iota = + xla::Iota(ctx->builder(), + xla::primitive_util::NativeToPrimitiveType(), size); auto result = XlaScatter(iota, iota, indices, /*indices_are_vectors=*/false, /*combiner=*/{}, ctx->builder()); @@ -167,8 +191,9 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), - InvertPermutationOp); +REGISTER_XLA_OP( + Name("InvertPermutation").TypeConstraint("T", {DT_INT32, DT_INT64}), + InvertPermutationOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 6d4393ee006..6fe6b164951 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -76,6 +77,8 @@ XLAJIT_MAKE_UNARY(Log1p, xla::Log1p(x)); XLAJIT_MAKE_UNARY(Invert, xla::Not(x)); XLAJIT_MAKE_UNARY(LogicalNot, xla::Not(x)); +XLAJIT_MAKE_UNARY(PopulationCount, + xla::ConvertElementType(xla::PopulationCount(x), xla::U8)); XLAJIT_MAKE_UNARY(Neg, -x); XLAJIT_MAKE_UNARY(Rint, xla::RoundToEven(x)); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index f0bd97c85eb..531679d3905 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -38,6 +38,7 @@ cc_library( hdrs = ["random.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 2ab86c78e44..e5913a8bbf3 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -66,7 +66,7 @@ xla::StatusOr Expand(xla::XlaOp input, int64 dim) { // Move the newly created dimension to the end with a transpose. std::vector permutation; - for (int64 i = 0, iter_limit = expanded_shape.size(); i != iter_limit; ++i) { + for (int64 i = 0, end = expanded_shape.size(); i != end; ++i) { permutation.push_back(i); if (i == dim) { ++i; diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 42a95bbb9f8..74ca16bbaeb 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -72,7 +72,7 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); - for (int i = 0, iter_limit = host_tensors.size(); i < iter_limit; i++) { + for (int i = 0, end = host_tensors.size(); i < end; i++) { // Validate runtime shapes and fail if it doesn't match the contract. 
const Tensor* tensor = &host_tensors[i]; buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 862da1f3f95..f4b9e9654d2 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -441,7 +441,8 @@ REGISTER_OP("XlaReduce") auto dim_in_range = [rank](int64 dim) { return dim >= 0 && dim < rank; }; - if (rank < dimensions_to_reduce.size() || + const int dimensions_to_reduce_size = dimensions_to_reduce.size(); + if (rank < dimensions_to_reduce_size || dims_set.size() != dimensions_to_reduce.size() || !absl::c_all_of(dimensions_to_reduce, dim_in_range)) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 0ebca2d546f..846dafa2570 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function from tensorflow.compiler.tf2xla.ops import gen_xla_ops +from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -415,8 +416,11 @@ sharding = gen_xla_ops.xla_sharding @ops.RegisterGradient("XlaSharding") def _sharding_grad(op, grad): - del op # Unused - return [grad] + grad_sharding = gen_xla_ops.xla_sharding(grad) + # pylint: disable=protected-access + grad_sharding.op._set_attr( + "_XlaSharding", attr_value_pb2.AttrValue(s=op.get_attr("_XlaSharding"))) + return [grad_sharding] spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index b6f8928f31e..ed7927a9999 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -41,7 +41,7 @@ std::vector ShuffleInputDataTypeAttribute( const std::vector& in_types, const std::vector& index_mapping) { std::vector result(index_mapping.size()); - for (int i = 0; i < in_types.size(); i++) { + for (int i = 0, end = in_types.size(); i < end; i++) { result[index_mapping.at(i)] = in_types[i]; } return result; @@ -56,7 +56,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, bool* need_rewrite, int* resource_input_count, std::vector* index_mapping) { int first_resource_index = -1; - for (int i = 0; i < in_types.size(); i++) { + for (int i = 0, end = in_types.size(); i < end; i++) { DataType type = in_types[i]; if (type == DT_RESOURCE) { first_resource_index = i; @@ -70,7 +70,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } *need_rewrite = false; - for (int i = first_resource_index + 1; i < in_types.size(); i++) { + for (int i = first_resource_index + 1, end = in_types.size(); i < end; i++) { if (in_types[i] != DT_RESOURCE) { *need_rewrite = true; break; @@ -81,7 +81,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, } *resource_input_count = 0; - for (int i = 0; i < in_types.size(); i++) { + for (int i = 0, end = in_types.size(); i < end; i++) { DataType type = in_types[i]; if (type == DT_RESOURCE) { ++(*resource_input_count); @@ -90,7 +90,7 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, int non_resource_index = 0, resource_index = in_types.size() - *resource_input_count; 
index_mapping->resize(in_types.size()); - for (int i = 0; i < in_types.size(); i++) { + for (int i = 0, end = in_types.size(); i < end; i++) { if (in_types[i] != DT_RESOURCE) { (*index_mapping)[i] = non_resource_index; non_resource_index++; @@ -180,7 +180,7 @@ Status CalculateRetvalRearrange( const gtl::InlinedVector& ret_nodes, // non-absl ok std::map* retval_index_mapping, std::map* resource_retval_to_arg) { - for (int i = 0; i < ret_nodes.size(); i++) { + for (int i = 0, end = ret_nodes.size(); i < end; i++) { Node* n = ret_nodes[i]; DataType t; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &t)); @@ -261,7 +261,7 @@ Status RearrangeOutputEdges(Node* n, Graph* g, void RearrangeRetvalNodes( const gtl::InlinedVector& ret_nodes, // non-absl ok Graph* g, const std::map& retval_index_mapping) { - for (int i = 0; i < ret_nodes.size(); i++) { + for (int i = 0, end = ret_nodes.size(); i < end; i++) { Node* n = ret_nodes[i]; auto iter = retval_index_mapping.find(i); if (iter == retval_index_mapping.end()) { @@ -317,7 +317,7 @@ Status MaybeRewriteWhileNode( // lambda resource_var1, resource_var2: [resource_var2, resource_var1], // [resource_var1, resource_var2]) if (attr_name == "body") { - for (int i = 0; i < fbody->ret_nodes.size(); i++) { + for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) { Node* n = fbody->ret_nodes[i]; DataType dtype; TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "T", &dtype)); @@ -349,7 +349,7 @@ Status MaybeRewriteWhileNode( RearrangeArgNodes(&fbody->arg_nodes, index_mapping); if (attr_name == "body") { - for (int i = 0; i < fbody->ret_nodes.size(); i++) { + for (int i = 0, end = fbody->ret_nodes.size(); i < end; i++) { Node* n = fbody->ret_nodes[i]; int new_index = index_mapping.at(i); if (new_index < types.size() - resource_input_count) { diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 366e8d49228..90585c9d98a 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -80,6 +80,30 @@ xla::StatusOr> ParseShardingFromDevice( return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } +xla::StatusOr> ParseShardingFromEdgeSource( + const Edge& edge, int num_cores_per_replica) { + if (edge.src() == nullptr) { + return tensorflow::errors::InvalidArgument( + "Null src for ParseShardingFromEdgeSource edge=", edge.DebugString()); + } + TF_ASSIGN_OR_RETURN( + absl::optional sharding, + ParseShardingFromDevice(*edge.src(), num_cores_per_replica)); + if (sharding.has_value() && + sharding.value().type() == xla::OpSharding::TUPLE) { + if (edge.src_output() < 0 || + edge.src_output() >= sharding.value().tuple_shardings_size()) { + return tensorflow::errors::InvalidArgument( + "Tuple index out of bound: edge=", edge.DebugString(), + " sharding=", sharding->DebugString()); + } + absl::optional subsharding = + sharding.value().tuple_shardings(edge.src_output()); + return subsharding; + } + return sharding; +} + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { string device_name = src.assigned_device_name(); if (device_name.empty()) { diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index 196434826f9..07657c656d3 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -43,6 +43,9 @@ xla::StatusOr> ParseShardingFromDevice( xla::StatusOr> ParseShardingFromDevice( const NodeDef& node_def, int num_cores_per_replica); +xla::StatusOr> 
ParseShardingFromEdgeSource( + const Edge& edge, int num_cores_per_replica); + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); // Get sharding inforamtion from node. diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 0454bbb771a..242a2b04ab9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -87,7 +87,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, *computation = std::move(*result.computation); int num_const_results = 0; - for (int i = 0, iter_limit = result.outputs.size(); i < iter_limit; ++i) { + for (int i = 0, end = result.outputs.size(); i < end; ++i) { // Ending up with const results (i.e. output args) is an error, since it // means that one or more fetches that the user specified will be dropped // from the generated function. It's most likely a configuration error, diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 5229104e674..8863b08b77b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -143,7 +143,7 @@ Status ReplaceArgUsageWithConstNode( usages.push_back({e->dst()->id(), e->dst_input()}); } - for (int i = 0; i < usages.size(); i++) { + for (int i = 0, end = usages.size(); i < end; i++) { // Make a copy of `usage_node`, and change its input to const node. Node* usage_node = g->FindNodeId(usages[i].dst_node_id); NodeDef replace_def = usage_node->def(); @@ -158,7 +158,7 @@ Status ReplaceArgUsageWithConstNode( // Later entries in `usages` might have `usage_node` as dst node, but // `usage_node` is removed. Replace such entries with `replace_node`. - for (int j = i + 1; j < usages.size(); j++) { + for (int j = i + 1, end = usages.size(); j < end; j++) { if (usages[j].dst_node_id == usages[i].dst_node_id) { usages[j].dst_node_id = replace_node->id(); } diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc new file mode 100644 index 00000000000..fe31025386e --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_argument.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_argument.h" + +namespace tensorflow { + +bool XlaArgument::operator==(const XlaArgument& other) const { + if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, + tensor_array_gradients) != + std::tie(other.kind, other.resource_kind, other.type, other.name, + other.initialized, other.max_array_size, + other.tensor_array_gradients)) { + return false; + } + if (absl::holds_alternative(shape)) { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get(shape), + absl::get(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative(other.shape)) { + return false; + } + if (absl::get(shape) != absl::get(other.shape)) { + return false; + } + } + if (constant_value.shape() != other.constant_value.shape()) { + return false; + } + if (is_same_data_across_replicas != other.is_same_data_across_replicas) { + return false; + } + return constant_value.tensor_data() == other.constant_value.tensor_data(); +} + +} // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h new file mode 100644 index 00000000000..e2cd634e1d5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Describes how to derive the value of each _Arg node in the graph/function +// being compiled. There must be one Argument for each _Arg index. +struct XlaArgument { + enum Kind { + // Default value; not a valid kind. + kInvalid, + + // Argument is a compile-time constant. No associated runtime parameter. + kConstant, + + // Argument is a Variable, TensorArray, or Stack resource. Has an + // associated runtime parameter iff `initialized` is true. + kResource, + + // Argument is a run-time parameter. + kParameter, + + // Argument is an XLA token. + kToken, + + // Argument is a TensorList. + kTensorList, + }; + + Kind kind = kInvalid; + + // The type of the argument. If the argument is a resource, this + // is the type of the variable's value, not DT_RESOURCE. + DataType type = DT_INVALID; + + // The shape of the argument. For: + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. 
+ // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + absl::variant shape; + + // The value of the argument, if it is a compile-time constant. Must be a + // host-memory tensor. + Tensor constant_value; + + // The name of this argument, used for debugging. + string name; + + // The name of TensorFlow _Arg node, used for debugging. + string node_name; + + // For a kResource, what kind of resource is it? + XlaResource::Kind resource_kind = XlaResource::kInvalid; + + // For a kResource, has this resource been initialized? + bool initialized = false; + + // For a kResource, is this resource on Fast Memory. + bool fast_mem = false; + + // For a TensorArray or Stack resource, what is the array's declared size? + // (Used for lazy initialization.) + int64 max_array_size = -1; + + // TensorArray resource parameters are passed as (array, gradient array 0, + // ..., gradient array k), where the gradient arrays are in the same order + // as `tensor_array_gradients`. + std::set tensor_array_gradients; + + // dynamic dims to arg number map. Empty if no dynamic shapes. + std::map dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + + // Whether this argument will receive the same data across all replicas. + bool is_same_data_across_replicas = false; + + bool operator==(const XlaArgument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector DimensionSizes() const; + absl::InlinedVector DimensionSizesAsInlinedVector() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 6d92fd97793..635b7170d82 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/rearrange_function_argument.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -52,6 +53,7 @@ limitations under the License. 
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/core/util/dump_graph.h" namespace tensorflow { @@ -64,7 +66,7 @@ Status CheckSignature(const DataTypeVector& types, return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); } - for (int i = 0, iter_limit = types.size(); i < iter_limit; ++i) { + for (int i = 0, end = types.size(); i < end; ++i) { // Don't perform type checks on resource variables and tensor // lists (DT_VARIANT) as we have to trick the type system in order to // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor. @@ -168,7 +170,7 @@ Status BuildComputation( int* num_computation_outputs, int* num_nonconst_outputs, std::vector* outputs, std::vector* resource_updates, - xla::Shape* output_shape) { + xla::Shape* output_shape, absl::Span input_mapping) { // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. @@ -192,7 +194,7 @@ Status BuildComputation( // replicate sharding is used. The first element is the output index, second // element is the sharding. std::unordered_map retval_index_and_sharding; - for (int i = 0, iter_limit = retvals.size(); i < iter_limit; ++i) { + for (int i = 0, end = retvals.size(); i < end; ++i) { XlaCompiler::OutputDescription& output = (*outputs)[i]; const XlaExpression& retval = retvals[i]; output.type = retval.dtype(); @@ -268,6 +270,11 @@ Status BuildComputation( return a->arg_num() < b->arg_num(); }); + absl::flat_hash_map argument_to_xla_arg; + for (int xla_arg = 0; xla_arg < input_mapping.size(); xla_arg++) { + argument_to_xla_arg[input_mapping[xla_arg]] = xla_arg; + } + std::vector aliases; for (const XlaResource* resource : arg_resources) { DCHECK_LT(resource->arg_num(), args.size()); @@ -290,19 +297,20 @@ Status BuildComputation( update.type = resource->type(); update.shape = resource->shape(); update.modified = modified; + int param_num = use_tuple_arg ? 0 : update.input_index; if (is_entry_computation && arg.resource_kind != XlaResource::kTensorArray && - alias_resource_update) { + alias_resource_update && argument_to_xla_arg.count(param_num)) { // Assuming tuple arg and results are used. xla::ShapeIndex param_index = use_tuple_arg ? xla::ShapeIndex({update.input_index}) : xla::ShapeIndex{}; - int param_number = use_tuple_arg ? 0 : update.input_index; + int xla_param_num = argument_to_xla_arg[param_num]; int64 output_index_num = elems.size(); xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num}); VLOG(3) << "Storing alias: " << output_index.ToString() << ": (" - << param_number << ", " << param_index.ToString() << ")"; - aliases.push_back({output_index, param_number, param_index}); + << xla_param_num << ", " << param_index.ToString() << ")"; + aliases.push_back({output_index, xla_param_num, param_index}); } for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); @@ -356,7 +364,7 @@ Status BuildComputation( xla::Shape shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); // Copy specified sharding from retval_index_and_sharding. 
std::vector sharding_elems; - for (int i = 0, iter_limit = elems.size(); i < iter_limit; i++) { + for (int i = 0, end = elems.size(); i < end; i++) { const auto& iter = retval_index_and_sharding.find(i); TF_RET_CHECK(iter != retval_index_and_sharding.end()); const xla::OpSharding& sub_op_sharding = iter->second; @@ -416,39 +424,6 @@ Status BuildComputation( } // namespace -bool XlaCompiler::Argument::operator==( - const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, - tensor_array_gradients) != - std::tie(other.kind, other.resource_kind, other.type, other.name, - other.initialized, other.max_array_size, - other.tensor_array_gradients)) { - return false; - } - if (absl::holds_alternative(shape)) { - if (!absl::holds_alternative(other.shape)) { - return false; - } - if (!xla::Shape::Equal()(absl::get(shape), - absl::get(other.shape))) { - return false; - } - } else { - if (!absl::holds_alternative(other.shape)) { - return false; - } - if (absl::get(shape) != absl::get(other.shape)) { - return false; - } - } - if (constant_value.shape() != other.constant_value.shape()) { - return false; - } - if (is_same_data_across_replicas != other.is_same_data_across_replicas) { - return false; - } - return constant_value.tensor_data() == other.constant_value.tensor_data(); -} string XlaCompiler::Argument::HumanString() const { string common; @@ -701,7 +676,7 @@ Status XlaCompiler::CompileFunction( // Set shapes for _Arg nodes. They are useful for constant folding (e.g. an // Xla op requires a compile-time constant input, and that input is shape of // an _Arg node. - for (int i = 0, iter_limit = args.size(); i < iter_limit; i++) { + for (int i = 0, end = args.size(); i < end; i++) { // Skip resource variables and tensor lists. DataType dtype; TF_RETURN_IF_ERROR(GetNodeAttr(fbody->arg_nodes[i]->def(), "T", &dtype)); @@ -753,8 +728,18 @@ Status XlaCompiler::CompileFunction( } VLOG(1) << "===================================================="; - TF_RETURN_IF_ERROR( - CompileGraph(options, function_id, std::move(graph), args, result)); + if (GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) { + VLOG(1) << "Using MLIR bridge"; + GraphDebugInfo debug_info; + TF_RETURN_IF_ERROR(CompileGraphToXlaHlo( + std::move(*graph), {args.data(), args.size()}, + options_.device_type.type_string(), options.use_tuple_arg, + *options_.flib_def, debug_info, options_.shape_representation_fn, + result)); + } else { + TF_RETURN_IF_ERROR( + CompileGraph(options, function_id, std::move(graph), args, result)); + } VLOG(1) << "===================================================="; cache_[{function_id, arg_vector}] = *result; @@ -943,7 +928,7 @@ Status XlaCompiler::BuildArguments( // to the d'th XLA input. Note that the value -1 corresponds to constants, or // other args that don't correspond to an input. std::vector arg_to_inputs(args.size(), -1); - for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; i++) { + for (int i = 0, end = input_to_args->size(); i < end; i++) { arg_to_inputs[input_to_args->at(i)] = i; } @@ -989,7 +974,7 @@ Status XlaCompiler::BuildArguments( : it->second; } std::vector is_same_across_replicas; - for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { + for (int i = 0, end = input_to_args->size(); i < end; ++i) { // Add an entry to is_same_across_replicas for every leaf buffer. 
is_same_across_replicas.insert( is_same_across_replicas.end(), @@ -1005,7 +990,7 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { + for (int i = 0, end = input_to_args->size(); i < end; ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); @@ -1024,6 +1009,11 @@ Status XlaCompiler::BuildArguments( xla::XlaScopedShardingAssignment assign_sharding( builder, it == arg_shardings.end() ? absl::optional() : it->second); + auto& arg = args[input_to_args->at(i)]; + + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name(arg.node_name); + builder->SetOneShotOpMetadata(arg_metadata); arg_handles[i] = xla::GetTupleElement(tuple, i); } } else { @@ -1046,7 +1036,7 @@ Status XlaCompiler::BuildArguments( } } - for (int i = 0, iter_limit = input_to_args->size(); i < iter_limit; ++i) { + for (int i = 0, end = input_to_args->size(); i < end; ++i) { const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); @@ -1315,7 +1305,8 @@ Status XlaCompiler::CompileGraph( options.always_return_tuple, options.use_tuple_arg, options.alias_resource_update, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, - &result->resource_updates, &result->xla_output_shape)); + &result->resource_updates, &result->xla_output_shape, + result->input_mapping)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -1366,7 +1357,7 @@ void SetTransfer(const string& key, absl::Span types, tf2xla::HostTransferMetadata* transfer) { transfer->set_key(key); CHECK(types.size() == shapes.size()); - for (int i = 0, iter_limit = types.size(); i < iter_limit; ++i) { + for (int i = 0, end = types.size(); i < end; ++i) { tf2xla::TensorMetadata* metadata = transfer->add_metadata(); metadata->set_type(types[i]); shapes[i].AsProto(metadata->mutable_shape()); @@ -1482,93 +1473,4 @@ xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { return iter->second; } -XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() { - return [](const TensorShape& shape, DataType dtype, - bool use_fast_memory) -> xla::StatusOr { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; -} - -// Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( - const absl::optional& sharding, bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape) { - if (sharding && !sharding->IsTileMaximal()) { - // After sharding, per core shape might have different layout. For example, - // before sharding, a shape [128, 128] will be assigned default - // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, - // the sharded shapes will have minor-to-major {0, 1}. - // - // As a result, for sharded shapes, we set their layout to per core shape's - // layout. - // - // TODO(endlessroad): for variable input & update, we might have - // different layouts which will prevent input output aliasing and - // increase memory usage. Investigate such cases. 
- int64 device = *sharding->tile_assignment().begin(); - std::vector offset = - sharding->TileOffsetForDevice(*xla_shape, device); - std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->rank()); - for (int64 i = 0; i < xla_shape->rank(); ++i) { - dimensions[i] = limit[i] - offset[i]; - } - xla::Shape per_device_xla_shape = - xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); - TensorShape per_device_tensor_shape; - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - xla_shape->element_type())); - TF_ASSIGN_OR_RETURN(per_device_xla_shape, - shape_representation_fn(per_device_tensor_shape, dtype, - use_fast_memory)); - *xla_shape->mutable_layout() = per_device_xla_shape.layout(); - } - return Status::OK(); -} - -// There is a shape_representation_fn or sharding for an output, this function -// uses a reshape to fix the layout. -xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional sharding, bool fast_mem) { - if (original_shape.IsTuple()) { - std::vector elements; - for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { - auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; - TF_ASSIGN_OR_RETURN(auto element, - ReshapeWithCorrectRepresentationAndSharding( - builder, xla::GetTupleElement(original, i), - original_shape.tuple_shapes(i), - shape_representation_fn, subsharding, fast_mem)); - elements.push_back(element); - } - return xla::Tuple(builder, elements); - } - if (!original_shape.IsArray()) return original; - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - original_shape.element_type())); - TF_ASSIGN_OR_RETURN(auto to_shape, - shape_representation_fn(shape, dtype, fast_mem)); - if (sharding) { - TF_ASSIGN_OR_RETURN(auto hlo_sharding, - xla::HloSharding::FromProto(*sharding)); - TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( - hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); - } - if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64 i = 0; i < original_shape.rank(); ++i) { - to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); - } - } - return xla::Reshape(to_shape, original); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b95d250636a..b0d93cde846 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,8 +21,10 @@ limitations under the License. #include "absl/types/span.h" #include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -97,96 +99,7 @@ class XlaContext; // `tensor_array_gradients` ordered set. 
class XlaCompiler { public: - // Describes how to derive the value of each _Arg node in the graph/function - // being compiled. There must be one Argument for each _Arg index. - struct Argument { - enum Kind { - // Default value; not a valid kind. - kInvalid, - - // Argument is a compile-time constant. No associated runtime parameter. - kConstant, - - // Argument is a Variable, TensorArray, or Stack resource. Has an - // associated runtime parameter iff `initialized` is true. - kResource, - - // Argument is a run-time parameter. - kParameter, - - // Argument is an XLA token. - kToken, - - // Argument is a TensorList. - kTensorList, - }; - - Kind kind = kInvalid; - - // The type of the argument. If the argument is a resource, this - // is the type of the variable's value, not DT_RESOURCE. - DataType type = DT_INVALID; - - // The shape of the argument. For: - // * a parameter: the shape of the parameter. We allow setting the xla shape - // if known. This helps avoid conversions to and from TensorShape. - // * a constant: ignored; the shape given by constant_value is used - // instead. - // * an uninitialized resource: ignored. We don't yet know the shape of an - // uninitialized resource (otherwise we would have initialized it!) - // * an initialized variable: the shape of the variable's value. - // * an initialized TensorArray or Stack resource: the shape of an entry in - // the TensorArray/Stack. Note this is the size of a single entry, not the - // XLA data structure that represents the complete stack/array. - absl::variant shape; - - // The value of the argument, if it is a compile-time constant. Must be a - // host-memory tensor. - Tensor constant_value; - - // The name of this argument, used for debugging. - string name; - - // The name of TensorFlow _Arg node, used for debugging. - string node_name; - - // For a kResource, what kind of resource is it? - XlaResource::Kind resource_kind = XlaResource::kInvalid; - - // For a kResource, has this resource been initialized? - bool initialized = false; - - // For a kResource, is this resource on Fast Memory. - bool fast_mem = false; - - // For a TensorArray or Stack resource, what is the array's declared size? - // (Used for lazy initialization.) - int64 max_array_size = -1; - - // TensorArray resource parameters are passed as (array, gradient array 0, - // ..., gradient array k), where the gradient arrays are in the same order - // as `tensor_array_gradients`. - std::set tensor_array_gradients; - - // dynamic dims to arg number map. Empty if no dynamic shapes. - std::map dynamic_dim_to_arg_num_map; - bool is_pad_arg = false; - - // Whether this argument will receive the same data across all replicas. - bool is_same_data_across_replicas = false; - - bool operator==(const Argument& other) const; - - // Returns a human-readable summary of the argument. - string HumanString() const; - - // Returns the dimension sizes for either TensorShape or xla::Shape. - std::vector DimensionSizes() const; - absl::InlinedVector DimensionSizesAsInlinedVector() const; - - // Returns the human-readable string for either TensorShape or xla::Shape. - string ShapeHumanString() const; - }; + using Argument = ::tensorflow::XlaArgument; // Options pertaining to an individual call to CompileGraph() or // CompileFunction(). @@ -221,77 +134,11 @@ class XlaCompiler { bool alias_resource_update = false; }; - struct OutputDescription { - // Type and shape of the output. The shape is the unflattened shape. 
- // When `type` is DT_RESOURCE, `shape` is the shape of the resource - // variable's value. - DataType type; - TensorShape shape; + using OutputDescription = ::tensorflow::XlaOutputDescription; - // Constant output value, if known to be constant at JIT compilation time. - // 'Tensor' is in host memory. - bool is_constant = false; - Tensor constant_value; + using ResourceUpdate = ::tensorflow::XlaResourceUpdate; - // When this output is a resource, i.e. `type == DT_RESOURCE`, this is - // the index of the input that contains the resource. - int input_index; - - // Whether this output is a TensorList. - bool is_tensor_list = false; - }; - - // Describes a variable write side effect of the computation. - struct ResourceUpdate { - // Index of the input that contains the variable resource to write to. - int input_index; - - // Type and shape of the tensor to be written back. - // The `shape` field has the same meaning as the Argument::shape field. - DataType type; - TensorShape shape; - - // Was the value of the variable modified by the computation? - // (Always true, unless `return_updated_values_for_all_resources` is true.) - bool modified; - - // If the resource is a TensorArray, the set of gradients read or written. - std::set tensor_array_gradients_accessed; - }; - - struct CompilationResult { - // Vector that maps from the parameters of the XLA computation to their - // original argument positions. To handle compile-time constant inputs, the - // parameters to the XLA computation may be a subset of the original - // arguments. The relative ordering of parameters are maintained. - std::vector input_mapping; - - // Input shapes of the computation. If we are flattening inputs, these are - // the flattened shapes. - std::vector xla_input_shapes; - - // Output shape in XLA format. The output shape is always a tuple. If we - // are flattening outputs, these are the flattened shapes. - xla::Shape xla_output_shape; - - // TensorFlow shapes of outputs, together with the values of any - // constant arguments. Vector indexed by Tensorflow _Retval number, - // containing both constant and non-constant results. - std::vector outputs; - - // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their - // matching RecvAtHost/SendFromHost Ops in the outer graph. - tf2xla::HostComputeMetadata host_compute_metadata; - - // Resources whose values were updated by the computation, ordered - // by return value position (which is the same as the order the resources - // were passed as arguments). Resource updates follow the non-constant - // results in the outputs of XLA computation. - std::vector resource_updates; - - // The XLA computation built from the tensorflow subgraph. - std::shared_ptr computation; - }; + using CompilationResult = ::tensorflow::XlaCompilationResult; typedef std::function(const TensorShape&, DataType, bool)> @@ -518,21 +365,6 @@ class XlaCompiler { TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; -// Creates an identity shape representation function. -XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn(); - -// Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( - const absl::optional& sharding, bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape); - -// Adds reshapes to fix the layout of an output, if a shape_representation_fn or -// sharding is present. 
-xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional sharding, bool fast_mem); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4f1b6c8e7a9..5df508d60b3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1856,5 +1856,46 @@ TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal)); } +TEST_F(XlaCompilerTest, AliasResourceUpdates) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::Const(scope.WithOpName("A"), {1, 2}); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto d = ops::_Retval(scope.WithOpName("D"), read, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kConstant; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[0].constant_value = Tensor(DT_INT32, {1, 1}); + args[0].initialized = true; + + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompileOptions compile_options; + compile_options.alias_resource_update = true; + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + const xla::HloInputOutputAliasProto& alias = + result.computation->proto().input_output_alias(); + EXPECT_EQ(alias.entries_size(), 1); + EXPECT_EQ(alias.entries(0).parameter_number(), 0); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index c94c4805d53..cb5bf34208f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index eb4ad3fe6a1..e44ac05b702 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,7 +20,6 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -33,6 +32,7 @@ limitations under the License. 
namespace tensorflow { class XlaOpKernelContext; +class XlaCompiler; // The XlaContext is the data structure that holds the state of an XLA // compilation, that is accessible from OpKernelContexts when compiling a diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 49f108ed6c8..f0cc8d26709 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -101,6 +101,48 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { }); } +xla::StatusOr XlaExpression::ResolveDynamism( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: { + // Constant values are considered static. + Tensor constant_false(DT_BOOL, constant_value().shape()); + auto flat = constant_false.flat(); + for (int64 i = 0; i < flat.size(); ++i) flat(i) = false; + return constant_false; + } + case Kind::kXlaOp: + break; + case Kind::kTensorList: + TF_FALLTHROUGH_INTENDED; + case Kind::kResource: + TF_FALLTHROUGH_INTENDED; + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveDynamism called on unsupported XlaExpression: ", + HumanString()); + } + + if (!client) + return errors::InvalidArgument("client is required to resolve constant"); + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildDynamicInferenceGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. + std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor(DT_BOOL); + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, DT_BOOL, &tensor)); + return tensor; +} + xla::StatusOr> XlaExpression::ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one) const { switch (kind()) { @@ -163,4 +205,23 @@ xla::StatusOr XlaExpression::GetShape() const { } } +const XlaExpression* XlaExpression::CastExpressionFromTensor( + const Tensor& tensor) { + const XlaExpression* expression = + reinterpret_cast(tensor.tensor_data().data()); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); + return expression; +} + +// Assigns an XlaExpression to a tensor on an XLA compilation device. +void XlaExpression::AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor) { + const XlaExpression* expression = + reinterpret_cast(tensor->tensor_data().data()); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 5d0bb35b182..3546368ff7b 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -99,11 +99,22 @@ class XlaExpression { xla::StatusOr> ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one = false) const; + // ResolveDynamism computes where a value inside this op is dynamic or can be + // inferred at compile time. + xla::StatusOr ResolveDynamism(xla::Client* client) const; + // Returns the shape of the tensor. 
// The shape of a resource is the shape of a resource handle (i.e., a scalar), // not the shape of the resource's value. xla::StatusOr GetShape() const; + // Retrieves an XlaExpression that was allocated by a previous Op. + static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); + + // Assigns an XlaExpression to a tensor on an XLA compilation device. + static void AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor); + private: Kind kind_ = Kind::kInvalid; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 74247bbaec7..8c4b55aec8a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -22,8 +22,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -128,4 +126,93 @@ xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, return xla::ConvertElementType(operand, convert_to); } +XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() { + return [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; +} + +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional& sharding, bool use_fast_memory, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape) { + if (sharding && !sharding->IsTileMaximal()) { + // After sharding, per core shape might have different layout. For example, + // before sharding, a shape [128, 128] will be assigned default + // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, + // the sharded shapes will have minor-to-major {0, 1}. + // + // As a result, for sharded shapes, we set their layout to per core shape's + // layout. + // + // TODO(endlessroad): for variable input & update, we might have + // different layouts which will prevent input output aliasing and + // increase memory usage. Investigate such cases. + int64 device = *sharding->tile_assignment().begin(); + std::vector offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); + std::vector dimensions(xla_shape->rank()); + for (int64 i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TensorShape per_device_tensor_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape->element_type())); + TF_ASSIGN_OR_RETURN(per_device_xla_shape, + shape_representation_fn(per_device_tensor_shape, dtype, + use_fast_memory)); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return Status::OK(); +} + +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. 
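The shape_representation_fn hook that these helpers thread through is what decides the concrete XLA layout for each TensorFlow tensor. As an illustration only (the function name below is hypothetical and not part of this change), a caller could supply a callback that pins every tensor to a descending minor-to-major layout instead of using the IdentityShapeRepresentationFn added above:

#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Illustrative sketch: always lay tensors out with minor-to-major
// {rank-1, ..., 1, 0}, ignoring the fast-memory hint.
XlaHelpers::ShapeRepresentationFn DescendingLayoutShapeRepresentationFn() {
  return [](const TensorShape& shape, DataType dtype,
            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::PrimitiveType type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
    std::vector<int64> dims(shape.dims());
    for (int i = 0; i < shape.dims(); ++i) dims[i] = shape.dim_size(i);
    std::vector<int64> minor_to_major(shape.dims());
    std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
    return xla::ShapeUtil::MakeShapeWithLayout(type, dims, minor_to_major);
  };
}

}  // namespace tensorflow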
+xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 490923526bd..3a9375ec1f4 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,8 +19,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "absl/types/span.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { @@ -72,6 +73,98 @@ class XlaHelpers { // than the xla::PrimitiveType. static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type); + + typedef std::function(const TensorShape&, DataType, + bool)> + ShapeRepresentationFn; +}; + +// Creates an identity shape representation function. +XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional& sharding, bool use_fast_memory, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem); + +struct XlaOutputDescription { + // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. + DataType type; + TensorShape shape; + + // Constant output value, if known to be constant at JIT compilation time. 
+ // 'Tensor' is in host memory. + bool is_constant = false; + Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; + + // Whether this output is a TensorList. + bool is_tensor_list = false; +}; + +// Describes a variable write side effect of the computation. +struct XlaResourceUpdate { + // Index of the input that contains the variable resource to write to. + int input_index; + + // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. + DataType type; + TensorShape shape; + + // Was the value of the variable modified by the computation? + // (Always true, unless `return_updated_values_for_all_resources` is true.) + bool modified; + + // If the resource is a TensorArray, the set of gradients read or written. + std::set tensor_array_gradients_accessed; +}; + +struct XlaCompilationResult { + // Vector that maps from the parameters of the XLA computation to their + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters are maintained. + std::vector input_mapping; + + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. + std::vector xla_input_shapes; + + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. + xla::Shape xla_output_shape; + + // TensorFlow shapes of outputs, together with the values of any + // constant arguments. Vector indexed by Tensorflow _Retval number, + // containing both constant and non-constant results. + std::vector outputs; + + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + + // Resources whose values were updated by the computation, ordered + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant + // results in the outputs of XLA computation. + std::vector resource_updates; + + // The XLA computation built from the tensorflow subgraph. + std::shared_ptr computation; }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 27766408716..07537546d52 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -49,33 +49,13 @@ XlaCompiler* XlaOpKernelContext::compiler() const { return xla_context()->compiler(); } -// Retrieves an XlaExpression that was allocated by a previous Op. -const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor( - const Tensor& tensor) { - const XlaExpression* expression = - reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->kind() != XlaExpression::Kind::kInvalid) - << expression->HumanString(); - return expression; -} - -// Assigns an XlaExpression to a tensor on an XLA compilation device. 
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value, - Tensor* tensor) { - const XlaExpression* expression = - reinterpret_cast(tensor->tensor_data().data()); - CHECK(expression->kind() == XlaExpression::Kind::kInvalid) - << expression->HumanString(); - *const_cast(expression) = value; -} - const XlaExpression& XlaOpKernelContext::InputExpression(int index) { - return *CastExpressionFromTensor(context_->input(index)); + return *XlaExpression::CastExpressionFromTensor(context_->input(index)); } const XlaExpression& XlaOpKernelContext::InputExpression( absl::string_view name) { - return *CastExpressionFromTensor(GetInputTensorByName(name)); + return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name)); } xla::XlaOp XlaOpKernelContext::Input(int index) { @@ -108,7 +88,8 @@ DataType XlaOpKernelContext::input_type(int index) const { if (type == DT_UINT8) { // Masqueraded XlaExpression could have different type. See // XlaOpKernelContext::SetOutputExpression for details. - auto expression = CastExpressionFromTensor(context_->input(index)); + auto expression = + XlaExpression::CastExpressionFromTensor(context_->input(index)); type = expression->dtype(); } return type; @@ -120,7 +101,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) { if (type == DT_UINT8) { // Masqueraded XlaExpression could have different type. See // XlaOpKernelContext::SetOutputExpression for details. - auto expression = CastExpressionFromTensor(tensor); + auto expression = XlaExpression::CastExpressionFromTensor(tensor); type = expression->dtype(); } return type; @@ -262,6 +243,48 @@ Status XlaOpKernelContext::ConstantInputAsFloatScalar(int index, double* out) { return LiteralToFloat64Scalar(literal, out); } +static Status LiteralToPredVector(const xla::LiteralSlice& literal, + std::vector* out) { + if (literal.shape().rank() != 1) { + return errors::InvalidArgument("value is not 1D, rank: ", + literal.shape().rank()); + } + int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); + if (literal.shape().element_type() != xla::PRED) { + return errors::InvalidArgument("value is not PRED"); + } + for (int64 i = 0; i < size; ++i) { + out->push_back(literal.Get({i})); + } + return Status::OK(); +} + +Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( + int index, std::vector* out) { + xla::Literal literal; + XlaExpression e = InputExpression(index); + auto* client = compiler() ? compiler()->client() : nullptr; + xla::StatusOr dynamism_or_status = e.ResolveDynamism(client); + if (!dynamism_or_status.ok()) { + Status status = dynamism_or_status.status(); + errors::AppendToMessage(&status, "while evaluating input dynamism", index, + " of ", context_->op_kernel().type_string()); + return status; + } + Tensor dynamism = dynamism_or_status.ValueOrDie(); + + Tensor temp(dynamism.dtype()); + TensorShape tensor_shape({InputShape(index).num_elements()}); + if (!temp.CopyFrom(dynamism, tensor_shape)) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + dynamism.shape().DebugString(), " which is not a R1 ", tensor_shape); + } + + TF_ASSIGN_OR_RETURN(literal, HostTensorToLiteral(temp)); + return LiteralToPredVector(literal, out); +} + // Converts an int32 or int64 1D literal to an int64 vector. 
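The ResolveInputDynamismIntoPredVector entry point added above is what a kernel calls to learn, element by element, whether an input is known at compile time. A hypothetical kernel-side sketch follows (the function name is illustrative, and `out` is taken to be a std::vector<bool>, matching the implementation above):

#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Illustrative only: inspect which elements of input 0 are compile-time
// constants before deciding how to lower the op.
void InspectInputDynamism(XlaOpKernelContext* ctx) {
  std::vector<bool> dynamism;
  OP_REQUIRES_OK(ctx, ctx->ResolveInputDynamismIntoPredVector(0, &dynamism));
  int num_static = 0;
  for (bool is_dynamic : dynamism) {
    // A false entry means the element is known at compile time and could be
    // read through the ConstantInput*() helpers rather than lowered as an op.
    if (!is_dynamic) ++num_static;
  }
  VLOG(2) << num_static << " of " << dynamism.size()
          << " elements of input 0 are static.";
}

}  // namespace tensorflow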
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { @@ -385,7 +408,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); + handles->push_back( + XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); @@ -408,7 +432,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = - XlaOpKernelContext::CastExpressionFromTensor(tensor); + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); @@ -461,7 +485,8 @@ Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); - const XlaExpression* expression = CastExpressionFromTensor(tensor); + const XlaExpression* expression = + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); @@ -502,8 +527,8 @@ void XlaOpKernelContext::SetOutputExpression(int index, TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); } - XlaOpKernelContext::AssignExpressionToTensor( - expression, context_->mutable_output(index)); + XlaExpression::AssignExpressionToTensor(expression, + context_->mutable_output(index)); return Status::OK(); }(); if (!status.ok()) { @@ -542,7 +567,7 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { const XlaExpression* expression = - CastExpressionFromTensor(context_->input(index)); + XlaExpression::CastExpressionFromTensor(context_->input(index)); TF_RET_CHECK(expression->resource() != nullptr); *resource = expression->resource(); return Status::OK(); @@ -554,7 +579,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = - XlaOpKernelContext::CastExpressionFromTensor(tensor); + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6987b6fbb98..75c3e60171a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,6 +17,9 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -113,6 +116,9 @@ class XlaOpKernelContext { // returns a one-element list. Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); + // Evaluates input and returns their dynamism vector in a vector of + // predicates. + Status ResolveInputDynamismIntoPredVector(int index, std::vector* out); // Helper methods for constant inputs. @@ -284,13 +290,6 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::XlaComputation* GetOrCreateMul(const DataType type); - // Assigns an XlaExpression to a tensor on an XLA compilation device. - static void AssignExpressionToTensor(const XlaExpression& value, - Tensor* tensor); - - // Retrieves an XlaExpression that was assigned to the specified tensor. - static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); - private: // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 32d42cb8a42..bec0b46611d 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_builder.h" diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 67bad0f8af7..a85d551769c 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -289,13 +289,19 @@ class Array { } // Fills the array with random normal variables with the specified mean. - void FillRandom(const T& stddev, const double mean = 0.0, - const int seed = 12345) { + void FillRandom(const T& stddev, double mean = 0.0, int seed = 12345) { + FillRandomDouble(static_cast(stddev), mean, seed); + } + + void FillRandomDouble(double stddev, double mean = 0.0, int seed = 12345) { std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(stddev)); + std::normal_distribution distribution(mean, stddev); for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); + if (std::is_same()) { + values_[i] = static_cast(distribution(g) > 0.0); + } else { + values_[i] = static_cast(distribution(g)); + } } } @@ -403,7 +409,8 @@ class Array { // Returns the size of the dimension at the given index. 
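One consumer-visible effect of the FillRandom change above: boolean (PRED) arrays can now be filled, since each sample is thresholded at zero. A short usage sketch (illustrative only):

#include <vector>

#include "tensorflow/compiler/xla/array.h"

// Illustrative only: fill a PRED-shaped array and a float array with random
// values; both paths now go through FillRandomDouble.
void FillRandomArrays() {
  const std::vector<xla::int64> dims = {4, 4};

  xla::Array<bool> mask(dims);
  // Each element becomes (sample > 0.0), so mean 0.0 gives roughly 50% true.
  mask.FillRandomDouble(/*stddev=*/1.0, /*mean=*/0.0, /*seed=*/42);

  xla::Array<float> noise(dims);
  noise.FillRandom(/*stddev=*/0.1f);  // forwards to FillRandomDouble
}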
int64 dim(int64 n) const { - CHECK(n < sizes_.size()); + const int64 sizes_size = sizes_.size(); + CHECK(n < sizes_size); return sizes_[n]; } @@ -427,7 +434,7 @@ class Array { if (sizes_.size() != other.sizes_.size()) { return false; } - for (int64 i = 0; i < sizes_.size(); ++i) { + for (int64 i = 0, end = sizes_.size(); i < end; ++i) { if (sizes_[i] != other.sizes_[i]) { return false; } diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 4f020bcec27..09449aeb8b8 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -312,7 +312,7 @@ StatusOr> Client::Execute( // device 0. // // TODO(b/118493728): Allow Execute to return one result per computation. - for (int64 i = 0; i < results.size(); i++) { + for (int64 i = 0, end = results.size(); i < end; i++) { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i])); if (!ShapeUtil::IsEmptyTuple(shape)) { VLOG(3) << "Fetching result from device " << i << ": " @@ -350,7 +350,7 @@ StatusOr>> Client::ExecuteParallel( } std::vector> outputs; - for (size_t i = 0; i < response.responses_size(); ++i) { + for (size_t i = 0, end = response.responses_size(); i < end; ++i) { outputs.push_back( absl::make_unique(stub_, response.responses(i).output())); if (i < computations.size() && diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 404f9eb7519..f39a3e79fe5 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -76,6 +76,12 @@ ExecutableBuildOptions& ExecutableBuildOptions::set_use_spmd_partitioning( return *this; } +ExecutableBuildOptions& ExecutableBuildOptions::set_deduplicate_hlo( + bool deduplicate_hlo) { + deduplicate_hlo_ = deduplicate_hlo; + return *this; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_device_assignment( const DeviceAssignment& device_assignment) { device_assignment_ = device_assignment; diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 9a7fdd974b1..d034eaa7fd6 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -82,6 +82,9 @@ class ExecutableBuildOptions { bool use_spmd_partitioning() const { return use_spmd_partitioning_; } ExecutableBuildOptions& set_use_spmd_partitioning(bool use_spmd_partitioning); + bool deduplicate_hlo() const { return deduplicate_hlo_; } + ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo); + // If set, this specifies a static device assignment for the computation. 
// Otherwise, the computation will be compiled generically and can be run with // any device assignment compatible with the computation's replica and @@ -110,6 +113,7 @@ class ExecutableBuildOptions { int num_replicas_ = 1; int num_partitions_ = 1; bool use_spmd_partitioning_ = false; + bool deduplicate_hlo_ = false; absl::optional device_assignment_; bool alias_passthrough_params_ = false; }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 06fd8ceeb2b..a3c7c39e3ff 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -55,9 +55,13 @@ xla_test( cc_library( name = "comparators", srcs = ["comparators.cc"], - hdrs = ["comparators.h"], + hdrs = [ + "comparators.h", + "//tensorflow/compiler/xla:literal_util", + ], deps = [ ":constants", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto_cc", diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index 74e89b767cf..cd594a5cf39 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -32,85 +32,13 @@ limitations under the License. namespace xla { namespace { -using XlaOpGenerator = XlaOp (*)(XlaOp, XlaOp, absl::Span); - -XlaOp BitcastConvertFloatingPointToIntegral(const XlaOp& value, - int64 bit_width) { - PrimitiveType signed_type; - PrimitiveType unsigned_type; - XlaOp max_value; - switch (bit_width) { - case 16: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S16; - unsigned_type = U16; - break; - case 32: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S32; - unsigned_type = U32; - break; - case 64: - max_value = - ConstantR0(value.builder(), - static_cast(std::numeric_limits::max())); - signed_type = S64; - unsigned_type = U64; - break; - default: - return value.builder()->ReportError( - InvalidArgument("Invalid bit width %lld for Comparator floating " - "point parameter.", - bit_width)); - } - // Switch from a floating point value to a integer value in such a way that - // when using the integer value to compare, we get the same result for normal - // values, and -Nan is treated as the smallest value, and Nan is treated as - // the largest value. - // If f is a float, and - // x = bit_cast(f); - // y = x < 0 ? numeric_limits::max() - x : x; - // then y is ordered as an int32 such that finite values have the obvious - // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning - // and end of the ordering. - // Note that in order to avoid -x to overflow, we calculate - // numeric_limits::max() - x as unsigned, and then convert back to - // signed. - auto signed_value = BitcastConvertType(value, signed_type); - auto unsigned_value = BitcastConvertType(value, unsigned_type); - auto flipped_value = - BitcastConvertType(Sub(max_value, unsigned_value), signed_type); - auto is_negative = Lt(signed_value, Zero(value.builder(), signed_type)); - return Select(is_negative, flipped_value, signed_value); -} - -void ConvertFloatingPoint(const PrimitiveType& operand_type, XlaOp* lhs_param, - XlaOp* rhs_param) { - if (primitive_util::IsFloatingPointType(operand_type)) { - PrimitiveType compare_type = operand_type; - // Special-case handling for BF16. 
We currently do not support direct - // comparisons with BF16, so we convert to F32 and then use the F32 - // comparison logic. - if (compare_type == BF16) { - compare_type = F32; - *lhs_param = ConvertElementType(*lhs_param, F32); - *rhs_param = ConvertElementType(*rhs_param, F32); - } - int64 bit_width = primitive_util::BitWidth(compare_type); - *lhs_param = BitcastConvertFloatingPointToIntegral(*lhs_param, bit_width); - *rhs_param = BitcastConvertFloatingPointToIntegral(*rhs_param, bit_width); - } -} +using XlaCompareOp = XlaOp (*)(XlaOp, XlaOp, absl::Span); XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, - XlaBuilder* builder, XlaOpGenerator generator) { + XlaBuilder* builder, XlaCompareOp generator) { CHECK_NE(operand_types.size(), 0); - std::vector> generators(operand_types.size()); + std::vector> generators(operand_types.size()); generators[0] = generator; return CreateScalarComparisonComputation(name, operand_types, generators, builder); @@ -119,7 +47,7 @@ XlaComputation CreateScalarComparisonComputation( XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, - const std::vector>& generators, + const std::vector>& generators, XlaBuilder* builder) { // Create a default computation where we compare only the first two // parameters of type 'operand_types[0]'. @@ -146,7 +74,6 @@ XlaComputation CreateScalarComparisonComputation( absl::StrCat("p.", parameter_count, ".lhs")); auto rhs_param = Parameter(b.get(), parameter_count * 2 + 1, scalar_shape, absl::StrCat("p.", parameter_count, ".rhs")); - ConvertFloatingPoint(operand_type, &lhs_param, &rhs_param); lhs_params.emplace_back(lhs_param); rhs_params.emplace_back(rhs_param); if (generators[parameter_count].has_value()) { @@ -169,7 +96,8 @@ XlaComputation CreateScalarComparisonComputation( generators[i].value()(lhs_params[i], rhs_params[i], {}), result); if (i != last_generator_index) { - param_equal = And(param_equal, Eq(lhs_params[i], rhs_params[i])); + param_equal = + And(param_equal, EqTotalOrder(lhs_params[i], rhs_params[i])); } } } @@ -181,14 +109,14 @@ XlaComputation CreateScalarComparisonComputation( XlaComputation CreateScalarLtComputation( const std::vector& operand_types, XlaBuilder* builder) { return CreateScalarComparisonComputation("compare-less-than", operand_types, - builder, Lt); + builder, LtTotalOrder); } // Creates a scalar greater-than computation and returns it. XlaComputation CreateScalarGtComputation( const std::vector& operand_types, XlaBuilder* builder) { - return CreateScalarComparisonComputation("compare-greater-than", - operand_types, builder, Gt); + return CreateScalarComparisonComputation( + "compare-greater-than", operand_types, builder, GtTotalOrder); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/comparators.h b/tensorflow/compiler/xla/client/lib/comparators.h index 25924d4a4f4..a82a84799aa 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.h +++ b/tensorflow/compiler/xla/client/lib/comparators.h @@ -43,14 +43,13 @@ XlaComputation CreateScalarGtComputation( const std::vector& operand_types, XlaBuilder* builder); // Creates a scalar comparison computation and returns it. This function takes -// an std::vector> and compare the operands -// where the generator isn't nullopt with the specified comparator -// at that location. +// a vector of comparator functions to compare the operands where the function +// isn't nullopt with the specified comparator at that location. 
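These comparator factories are typically consumed by Sort. A small sketch (not part of this change; `builder` and `input` are assumed to already exist) of an ascending F32 sort, which after this change inherits total-order semantics for -0/0 and NaNs from the *TotalOrder comparison ops:

#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Illustrative only: sort `input` along its last dimension in ascending order.
xla::XlaOp SortAscending(xla::XlaBuilder* builder, xla::XlaOp input) {
  xla::XlaComputation less =
      xla::CreateScalarLtComputation({xla::F32}, builder);
  return xla::Sort({input}, less, /*dimension=*/-1, /*is_stable=*/false);
}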
XlaComputation CreateScalarComparisonComputation( const string& name, const std::vector& operand_types, const std::vector< absl::optional)>>& - generators, + comparators, XlaBuilder* builder); } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc index 6bd56a8df0a..4836dff7fa0 100644 --- a/tensorflow/compiler/xla/client/lib/constants.cc +++ b/tensorflow/compiler/xla/client/lib/constants.cc @@ -48,7 +48,9 @@ XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) { builder, static_cast(Eigen::NumTraits::epsilon())); case BF16: - return ConstantR0(builder, bfloat16::epsilon()); + return ConstantR0( + builder, static_cast( + Eigen::NumTraits::epsilon())); case F32: return ConstantR0(builder, std::numeric_limits::epsilon()); case F64: @@ -70,7 +72,8 @@ XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, Eigen::NumTraits::lowest()); case BF16: - return ConstantR0(builder, bfloat16::lowest()); + return ConstantR0( + builder, Eigen::NumTraits::lowest()); case F32: return ConstantR0(builder, -std::numeric_limits::max()); case F64: @@ -86,7 +89,8 @@ XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, std::numeric_limits::min()); case BF16: - return ConstantR0(builder, bfloat16::min_positive_normal()); + return ConstantR0( + builder, std::numeric_limits::min()); case F32: return ConstantR0(builder, std::numeric_limits::min()); case F64: @@ -108,7 +112,8 @@ XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0(builder, Eigen::NumTraits::highest()); case BF16: - return ConstantR0(builder, bfloat16::highest()); + return ConstantR0( + builder, Eigen::NumTraits::highest()); case F32: return ConstantR0(builder, std::numeric_limits::max()); case F64: @@ -125,8 +130,8 @@ XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) { return ConstantR0( builder, Eigen::NumTraits::quiet_NaN()); case BF16: - return ConstantR0( - builder, bfloat16(std::numeric_limits::quiet_NaN())); + return ConstantR0( + builder, Eigen::NumTraits::quiet_NaN()); case F32: return ConstantR0(builder, std::numeric_limits::quiet_NaN()); diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index baafd7d705b..6fdaab58686 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -511,7 +511,7 @@ XlaOp Lgamma(XlaOp input) { XlaOp z = Select(need_to_reflect, -input, input - one); XlaOp x = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); XlaOp index = ScalarLike(input, i); x = x + lanczos_coefficient / (z + index + one); @@ -647,7 +647,7 @@ XlaOp Digamma(XlaOp input) { XlaOp num = zero; XlaOp denom = base_lanczos_coeff; - for (int i = 0; i < kLanczosCoefficients.size(); ++i) { + for (int i = 0, end = kLanczosCoefficients.size(); i < end; ++i) { XlaOp lanczos_coefficient = ScalarLike(input, kLanczosCoefficients[i]); XlaOp index = ScalarLike(input, i); num = num - lanczos_coefficient / ((z + index + one) * (z + index + one)); diff --git a/tensorflow/compiler/xla/client/lib/pooling.cc b/tensorflow/compiler/xla/client/lib/pooling.cc index 45033ec07e7..fb04b147ff2 100644 --- a/tensorflow/compiler/xla/client/lib/pooling.cc +++ b/tensorflow/compiler/xla/client/lib/pooling.cc @@ -198,15 +198,17 @@ 
XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, XlaBuilder* b = out_backprop.builder(); return b->ReportErrorOrReturn([&]() -> StatusOr { const int num_dims = kernel_size.size(); - - if (gradients_size.size() != num_dims) { + const int num_gradients = gradients_size.size(); + if (num_gradients != num_dims) { return tensorflow::errors::InvalidArgument("gradients must be ", num_dims, "-dimensional"); } TF_ASSIGN_OR_RETURN(Shape out_backprop_xla_shape, b->GetShape(out_backprop)); - if (out_backprop_xla_shape.dimensions().size() != num_dims) { + const int backprop_xla_num_dims = + out_backprop_xla_shape.dimensions().size(); + if (backprop_xla_num_dims != num_dims) { return tensorflow::errors::InvalidArgument("out_backprop must be ", num_dims, "-dimensional"); } diff --git a/tensorflow/compiler/xla/client/lib/slicing.cc b/tensorflow/compiler/xla/client/lib/slicing.cc index 1ea713467f8..ebb35c5df82 100644 --- a/tensorflow/compiler/xla/client/lib/slicing.cc +++ b/tensorflow/compiler/xla/client/lib/slicing.cc @@ -74,12 +74,13 @@ XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start) { return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(x)); const int64 n_dims = shape.rank(); - TF_RET_CHECK(start.size() == n_dims); + const int64 start_size = start.size(); + TF_RET_CHECK(start_size == n_dims); // TODO(phawkins): make int64 work on all backends, remove the int32 cast. std::vector start_as_int32(start.begin(), start.end()); std::vector start_ops(start.size()); - for (int i = 0; i < start.size(); ++i) { + for (int i = 0, end = start.size(); i < end; ++i) { start_ops[i] = ConstantR0(builder, start_as_int32[i]); } return DynamicUpdateSlice(x, update, start_ops); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 5fc9909fa2a..1389f548c5d 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -122,12 +122,13 @@ LocalExecutable::RunHelper(const absl::Span argument_shapes, executable_->module_config().entry_computation_layout(); // Check argument number, shapes, and layouts. 
- if (argument_shapes.size() != computation_layout.parameter_count()) { + const int argument_shapes_size = argument_shapes.size(); + if (argument_shapes_size != computation_layout.parameter_count()) { return InvalidArgument( "invalid number of arguments for computation: expected %d, got %u", computation_layout.parameter_count(), argument_shapes.size()); } - for (int i = 0; i < argument_shapes.size(); ++i) { + for (int i = 0, end = argument_shapes.size(); i < end; ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( *argument_shapes[i])) { return InvalidParameterArgument( @@ -187,7 +188,7 @@ StatusOr LocalExecutable::Run( std::vector argument_shapes; argument_shapes.reserve(arguments.size()); for (const ExecutionInput& arg : arguments) { - argument_shapes.push_back(&arg.shape()); + argument_shapes.push_back(&arg.host_shape()); } return AsyncCallAndBlockHostUntilDone( argument_shapes, run_options, [&](const ExecutableRunOptions& options) { @@ -325,7 +326,7 @@ StatusOr LocalExecutable::RunAsync( std::vector argument_shapes; argument_shapes.reserve(arguments.size()); for (const ExecutionInput& arg : arguments) { - argument_shapes.push_back(&arg.shape()); + argument_shapes.push_back(&arg.host_shape()); } return RunAsync(argument_shapes, std::move(arguments), run_options); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 8b91f4a1739..bb072a0fe2c 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -64,10 +64,6 @@ class LocalExecutable { // Similar to RunAsync(), but allows for donating argument buffers to the // executable. - StatusOr RunAsync( - absl::Span argument_host_shapes, - std::vector arguments, ExecutableRunOptions run_options); - StatusOr RunAsync(std::vector arguments, ExecutableRunOptions run_options); @@ -78,6 +74,10 @@ class LocalExecutable { Executable* executable() const { return executable_.get(); } private: + StatusOr RunAsync( + absl::Span argument_host_shapes, + std::vector arguments, ExecutableRunOptions run_options); + // Validates that the given arguments and options satisfy various constraints // of the computation. // diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index cc6a680c4e9..2b69c71042d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -39,6 +40,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/errors.h" namespace xla { @@ -71,8 +73,75 @@ void SetProtoIdAndName(T* entry, const string& base_name, char separator, entry->set_id(id); entry->set_name(GetFullName(base_name, separator, id)); } + +ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) { + return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto(); +} + +HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape, + bool pred) { + HloInstructionProto const_instr; + Literal literal = LiteralUtil::CreateR0(pred); + Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie(); + *const_instr.mutable_shape() = shape.ToProto(); + *const_instr.mutable_literal() = literal_broadcast.ToProto(); + *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); + const_instr.set_id(id); + return const_instr; +} + +// Converts a HloComputation into ReducerOr with predicate types. +HloComputationProto CreateReduceOr(int64 reducer_id, + HloComputationProto* original_reducer) { + HloComputationProto reducer; + SetProtoIdAndName(&reducer, StrCat("reduce_or"), kNameSeparator, reducer_id); + std::vector operands_id; + for (auto& inst : original_reducer->instructions()) { + // Copy params. + if (StringToHloOpcode(inst.opcode()).ValueOrDie() == + HloOpcode::kParameter) { + HloInstructionProto* new_param = reducer.add_instructions(); + *new_param = inst; + *new_param->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + operands_id.push_back(inst.id()); + } + if (inst.id() == original_reducer->root_id()) { + HloInstructionProto* new_root = reducer.add_instructions(); + *new_root = inst; + *new_root->mutable_shape() = ConvertShapeProtoToPred(inst.shape()); + *new_root->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + new_root->clear_operand_ids(); + for (int64 operand_id : operands_id) { + new_root->add_operand_ids(operand_id); + } + reducer.set_root_id(inst.id()); + } + } + return reducer; +} } // namespace +namespace internal { + +XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation) { + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + instr.set_fusion_kind(std::string(fusion_kind)); + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(auto program_shape, + fused_computation.GetProgramShape()); + *instr.mutable_shape() = program_shape.result().ToProto(); + builder->AddCalledComputation(fused_computation, &instr); + return builder->AddInstruction(std::move(instr), HloOpcode::kFusion, + operands); + }); +} + +} // namespace internal + XlaOp operator-(XlaOp x) { return Neg(x); } XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); } XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); } @@ -425,7 +494,7 @@ StatusOr XlaBuilder::Build(int64 root_id, alias.param_index.ToString().c_str()); } TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number, - alias.param_index)); + alias.param_index, alias.kind)); } *module->mutable_input_output_alias() = config.ToProto(); return Status::OK(); @@ -508,7 +577,8 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, XlaOp operand) { XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, - absl::optional direction) { + absl::optional direction, + absl::optional type) { return 
ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape* lhs_shape, GetShapePtr(lhs)); TF_ASSIGN_OR_RETURN(const Shape* rhs_shape, GetShapePtr(rhs)); @@ -566,7 +636,11 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, return InvalidArgument( "kCompare expects a ComparisonDirection, but none provided."); } - return Compare(shape, updated_lhs, updated_rhs, *direction); + if (type == absl::nullopt) { + return Compare(shape, updated_lhs, updated_rhs, *direction); + } else { + return Compare(shape, updated_lhs, updated_rhs, *direction, *type); + } } if (direction.has_value()) { @@ -589,8 +663,16 @@ XlaOp XlaBuilder::BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { + return Compare(shape, lhs, rhs, direction, + Comparison::DefaultComparisonType(shape.element_type())); +} + +StatusOr XlaBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type) { HloInstructionProto instr; instr.set_comparison_direction(ComparisonDirectionToString(direction)); + instr.set_comparison_type(ComparisonTypeToString(type)); *instr.mutable_shape() = shape.ToProto(); return AddInstruction(std::move(instr), HloOpcode::kCompare, {lhs, rhs}); } @@ -766,15 +848,16 @@ XlaOp XlaBuilder::BroadcastInDim( TF_ASSIGN_OR_RETURN(auto output_shape, ShapeUtil::MakeValidatedShape( operand_shape->element_type(), out_dim_size)); - if (operand_shape->rank() != broadcast_dimensions.size()) { + tensorflow::int64 broadcast_rank = broadcast_dimensions.size(); + if (operand_shape->rank() != broadcast_rank) { return InvalidArgument( "Size of broadcast_dimensions has to match operand's rank; operand " "rank: %lld, size of broadcast_dimensions %u.", operand_shape->rank(), broadcast_dimensions.size()); } - for (int i = 0; i < broadcast_dimensions.size(); i++) { - if (broadcast_dimensions[i] < 0 || - broadcast_dimensions[i] > out_dim_size.size()) { + for (int i = 0; i < broadcast_rank; i++) { + const tensorflow::int64 num_dims = out_dim_size.size(); + if (broadcast_dimensions[i] < 0 || broadcast_dimensions[i] > num_dims) { return InvalidArgument("Broadcast dimension %lld is out of bound", broadcast_dimensions[i]); } @@ -786,7 +869,7 @@ XlaOp XlaBuilder::BroadcastInDim( *operand_shape, output_shape, broadcast_dimensions) .status()); std::vector in_dim_size(out_dim_size.begin(), out_dim_size.end()); - for (int i = 0; i < broadcast_dimensions.size(); i++) { + for (int i = 0; i < broadcast_rank; i++) { in_dim_size[broadcast_dimensions[i]] = operand_shape->dimensions(i); } const auto& in_dim_shape = @@ -835,7 +918,7 @@ StatusOr XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand, absl::Span strides) { HloInstructionProto instr; *instr.mutable_shape() = shape.ToProto(); - for (int i = 0; i < start_indices.size(); i++) { + for (int i = 0, end = start_indices.size(); i < end; i++) { auto* slice_config = instr.add_slice_dimensions(); slice_config->set_start(start_indices[i]); slice_config->set_limit(limit_indices[i]); @@ -1543,7 +1626,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span tokens) { if (tokens.empty()) { return InvalidArgument("AfterAll requires at least one operand"); } - for (int i = 0; i < tokens.size(); ++i) { + for (int i = 0, end = tokens.size(); i < end; ++i) { XlaOp operand = tokens[i]; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); if (!operand_shape->IsToken()) { @@ -1706,8 +1789,6 @@ XlaOp 
XlaBuilder::Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension, bool is_stable) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - instr.set_is_stable(is_stable); std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(std::vector operand_shapes, GetOperandShapes(operands)); @@ -1715,17 +1796,26 @@ XlaOp XlaBuilder::Sort(absl::Span operands, [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kSort, operand_shape_ptrs)); - *instr.mutable_shape() = shape.ToProto(); - if (dimension == -1) { - TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0])); - dimension = keys_shape->rank() - 1; - } - instr.add_dimensions(dimension); - AddCalledComputation(comparator, &instr); - return AddInstruction(std::move(instr), HloOpcode::kSort, operands); + return SortInternal(shape, operands, comparator, dimension, is_stable); }); } +StatusOr XlaBuilder::SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_is_stable(is_stable); + if (dimension == -1) { + TF_ASSIGN_OR_RETURN(const Shape* keys_shape, GetShapePtr(operands[0])); + dimension = keys_shape->rank() - 1; + } + instr.add_dimensions(dimension); + AddCalledComputation(comparator, &instr); + return AddInstruction(std::move(instr), HloOpcode::kSort, operands); +} + XlaOp XlaBuilder::ConvertElementType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1739,16 +1829,21 @@ XlaOp XlaBuilder::ConvertElementType(XlaOp operand, XlaOp XlaBuilder::BitcastConvertType(XlaOp operand, PrimitiveType new_element_type) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape( *operand_shape, new_element_type)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, - {operand}); + return BitcastConvertTypeInternal(shape, operand); }); } +StatusOr XlaBuilder::BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert, + {operand}); +} + XlaOp XlaBuilder::Clamp(XlaOp min, XlaOp operand, XlaOp max) { return TernaryOp(HloOpcode::kClamp, min, operand, max); } @@ -1870,8 +1965,6 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - // Infer shape. TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape()); TF_ASSIGN_OR_RETURN(const auto& condition_program_shape, @@ -1880,14 +1973,22 @@ XlaOp XlaBuilder::While(const XlaComputation& condition, TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape( condition_program_shape, body_program_shape, *init_shape)); - *instr.mutable_shape() = shape.ToProto(); - // Body comes before condition computation in the vector. 
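The Sort and While builder methods are being split into overridable *Internal hooks without changing the public client API; the hooks only give subclasses a place to intercept instruction creation. For reference, a minimal client-side sketch (illustrative, not part of this change) that builds `while (i < 10) i = i + 1` over an S32 scalar:

#include <cstdint>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Illustrative only: condition and body are built on their own sub-builders,
// then stitched together with the free-standing While() wrapper.
xla::XlaOp CountToTen(xla::XlaBuilder* builder) {
  const xla::Shape s32 = xla::ShapeUtil::MakeShape(xla::S32, {});

  xla::XlaBuilder cond_b("count_to_ten_cond");
  xla::Lt(xla::Parameter(&cond_b, 0, s32, "i"),
          xla::ConstantR0<int32_t>(&cond_b, 10));
  xla::XlaComputation cond = cond_b.Build().ValueOrDie();

  xla::XlaBuilder body_b("count_to_ten_body");
  xla::Add(xla::Parameter(&body_b, 0, s32, "i"),
           xla::ConstantR0<int32_t>(&body_b, 1));
  xla::XlaComputation body = body_b.Build().ValueOrDie();

  return xla::While(cond, body, xla::ConstantR0<int32_t>(builder, 0));
}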
- AddCalledComputation(body, &instr); - AddCalledComputation(condition, &instr); - return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); + return WhileInternal(shape, condition, body, init); }); } +StatusOr XlaBuilder::WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + // Body comes before condition computation in the vector. + AddCalledComputation(body, &instr); + AddCalledComputation(condition, &instr); + return AddInstruction(std::move(instr), HloOpcode::kWhile, {init}); +} + XlaOp XlaBuilder::Gather(XlaOp input, XlaOp start_indices, const GatherDimensionNumbers& dimension_numbers, absl::Span slice_sizes, @@ -2007,7 +2108,7 @@ XlaOp XlaBuilder::ConditionalImpl( std::vector branch_operand_shapes(branch_operands.size()); std::vector branch_computation_shapes( branch_computations.size()); - for (int j = 0; j < branch_operands.size(); ++j) { + for (int j = 0, end = branch_operands.size(); j < end; ++j) { TF_ASSIGN_OR_RETURN(branch_operand_shapes[j], GetShape(branch_operands[j])); TF_ASSIGN_OR_RETURN(branch_computation_shapes[j], @@ -2416,7 +2517,9 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension, if (layout) { TF_RET_CHECK(shape.IsTuple() && !ShapeUtil::IsNestedTuple(shape)); for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - if (layout->minor_to_major().size() != shape.tuple_shapes(i).rank()) { + const int64 layout_minor_to_major_size = + layout->minor_to_major().size(); + if (layout_minor_to_major_size != shape.tuple_shapes(i).rank()) { return InvalidArgument( "Provided layout must be compatible with the operand shape: %s " "vs %s", @@ -2800,6 +2903,196 @@ StatusOr XlaBuilder::IsConstant(XlaOp operand) const { return is_constant; } +StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { + TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, + LookUpInstruction(root_op)); + + HloComputationProto entry; + SetProtoIdAndName(&entry, StrCat(name_, "_dynamic_inference"), kNameSeparator, + GetNextId()); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); + *program_shape->mutable_result() = + ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto(); + + std::set seen; + struct WorkItem { + explicit WorkItem(int64 handle, bool need_rewrite) + : handle(handle), need_rewrite(need_rewrite) {} + int64 handle; + // If need_rewrite is true, the instruction will be copied and rewrite into + // a pred instruction indicating if each value is dynamic. If need_rewrite + // is false, simply copy the instruction to the output graph. + // E.g., + // For select(P, A, B), we need to rewrite A and B into predicates, but + // don't need to rewrite P. + bool need_rewrite; + }; + std::queue worklist; + worklist.push(WorkItem(root->id(), true)); + entry.set_root_id(root->id()); + std::vector called_computatons; + // Rewritre instruction with id "from" into the new graph. + // Returns more work items that need to finish. + auto rewrite_instruction = + [&](int64 from, bool need_rewrite) -> StatusOr> { + // Rewrite the instruction with following rules: + // - Unary ops: Convert into bitcast (identity) with type Pred. + // - Binary ops: Convert into binary or. + // - Select: Convert into binary or with its two data operands. + // - Concat / Tuple/ GTE / Bitcast: Copy. + // - Param: Convert to constant True. 
+ // - GetDimensionSize: Convert to constant True if dimension is dynamic, + // contant False if dimension is static. + // - Reduce: Convert to reduce or. + // - Constant: Convert to constant False. + // - Other ops: Not supported. + // Create the instruction for the new handle. + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(from)); + + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instr_proto->opcode())); + std::vector operands_todo; + auto* new_instr = entry.add_instructions(); + *new_instr = *instr_proto; + for (auto operand_id : new_instr->operand_ids()) { + operands_todo.emplace_back(operand_id, need_rewrite); + } + + if (!need_rewrite) { + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); + return operands_todo; + } + *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape()); + Shape new_shape(new_instr->shape()); + switch (opcode) { + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kCos: + case HloOpcode::kClz: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kPopulationCount: + case HloOpcode::kReal: + case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kConvert: + case HloOpcode::kSqrt: + case HloOpcode::kCbrt: + case HloOpcode::kTanh: + CHECK_EQ(instr_proto->operand_ids_size(), 1); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kBitcast); + break; + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kDivide: + case HloOpcode::kComplex: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kCompare: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(instr_proto->operand_ids_size(), 2); + *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); + break; + case HloOpcode::kSelect: + operands_todo[0].need_rewrite = false; + break; + case HloOpcode::kGather: + operands_todo[1].need_rewrite = false; + break; + case HloOpcode::kReduce: { + int64 reducer_id = new_instr->called_computation_ids(0); + called_computatons.push_back( + CreateReduceOr(reducer_id, &embedded_[reducer_id])); + break; + } + case HloOpcode::kTuple: + case HloOpcode::kTranspose: + case HloOpcode::kGetTupleElement: + case HloOpcode::kSlice: + case HloOpcode::kBroadcast: + case HloOpcode::kConcatenate: + case HloOpcode::kReshape: + break; + case HloOpcode::kGetDimensionSize: { + int64 dimension = instr_proto->dimensions(0); + int64 operand_handle = instr_proto->operand_ids(0); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, + LookUpInstructionByHandle(operand_handle)); + + *new_instr = CreateConstantInstruction( + from, new_shape, + operand_proto->shape().is_dynamic_dimension(dimension)); + operands_todo.clear(); + break; + } + case HloOpcode::kConstant: + *new_instr = CreateConstantInstruction(from, new_shape, false); + break; + case HloOpcode::kParameter: + *new_instr = 
CreateConstantInstruction(from, new_shape, true); + break; + default: + return InvalidArgument("Dynamic inferencing %s is not supported", + instr_proto->DebugString()); + } + *new_instr->mutable_name() = + GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); + return operands_todo; + }; + + while (!worklist.empty()) { + WorkItem item = worklist.front(); + worklist.pop(); + if (!seen.insert(item.handle).second) { + continue; + } + TF_ASSIGN_OR_RETURN(auto todos, + rewrite_instruction(item.handle, item.need_rewrite)); + for (WorkItem& todo : todos) { + worklist.push(todo); + } + } + absl::c_sort(*entry.mutable_instructions(), + [](const HloInstructionProto& p1, + const HloInstructionProto& p2) { return p1.id() < p2.id(); }); + XlaComputation computation(entry.id()); + HloModuleProto* module = computation.mutable_proto(); + module->set_name(entry.name()); + module->set_id(entry.id()); + module->set_entry_computation_name(entry.name()); + module->set_entry_computation_id(entry.id()); + *module->mutable_host_program_shape() = *program_shape; + for (auto& called_comp : called_computatons) { + *module->add_computations() = called_comp; + } + *module->add_computations() = std::move(entry); + XLA_VLOG_LINES(3, module->DebugString()); + return std::move(computation); +} + StatusOr XlaBuilder::BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_minus_one) { TF_ASSIGN_OR_RETURN(bool is_constant, IsConstant(root_op)); @@ -3021,7 +3314,12 @@ StatusOr XlaBuilder::AddInstruction(HloInstructionProto&& instr, instr.add_operand_ids(operand.handle()); } - *instr.mutable_metadata() = metadata_; + if (one_shot_metadata_.has_value()) { + *instr.mutable_metadata() = one_shot_metadata_.value(); + one_shot_metadata_.reset(); + } else { + *instr.mutable_metadata() = metadata_; + } if (sharding_) { *instr.mutable_sharding() = *sharding_; } @@ -3227,31 +3525,71 @@ XlaOp Eq(const XlaOp lhs, const XlaOp rhs, return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq); } +XlaOp EqTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kEq, + compare_type); +} + XlaOp Ne(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe); } +XlaOp NeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kNe, + compare_type); +} + XlaOp Ge(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe); } +XlaOp GeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGe, + compare_type); +} + XlaOp Gt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt); } +XlaOp GtTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kGt, + compare_type); +} + XlaOp Le(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return 
Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe); } +XlaOp LeTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + auto compare_type = Comparison::Type::kFloatTotalOrder; + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLe, + compare_type); +} XlaOp Lt(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions) { return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt); } +XlaOp LtTotalOrder(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions) { + return Compare(lhs, rhs, broadcast_dimensions, ComparisonDirection::kLt, + Comparison::Type::kFloatTotalOrder); +} + XlaOp Compare(const XlaOp lhs, const XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction) { @@ -3259,6 +3597,13 @@ XlaOp Compare(const XlaOp lhs, const XlaOp rhs, broadcast_dimensions, direction); } +XlaOp Compare(const XlaOp lhs, const XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type) { + return lhs.builder()->BinaryOp(HloOpcode::kCompare, lhs, rhs, + broadcast_dimensions, direction, compare_type); +} + XlaOp Compare(const XlaOp lhs, const XlaOp rhs, ComparisonDirection direction) { return Compare(lhs, rhs, {}, direction); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 60bdc32e68d..6d30195d3d0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -45,6 +46,16 @@ limitations under the License. namespace xla { class XlaBuilder; +class XlaOp; + +namespace internal { + +XlaOp XlaBuilderBuildFusion(XlaBuilder* builder, + absl::Span operands, + absl::string_view fusion_kind, + const XlaComputation& fused_computation); + +} // namespace internal // This represents an instruction that has been enqueued using the XlaBuilder. // This is used to pass to subsequent computations that depends upon the @@ -153,6 +164,11 @@ class XlaBuilder { // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + // Similar to SetOpMetadata, but only set the metadata for the next op. + void SetOneShotOpMetadata(OpMetadata metadata) { + metadata_ = std::move(metadata); + } + // Clears the HloMetadata state. void ClearOpMetadata() { metadata_.Clear(); } @@ -262,6 +278,31 @@ class XlaBuilder { StatusOr BuildConstantSubGraph( XlaOp root_op, bool dynamic_dimension_is_uint_max = false); + // Similar to BuildConstantSubGraph, but with root element type changed to + // boolean. A true value in the root indicates that the value is dynamic while + // false value indicates that the value is a constant. This will copy the + // needed ops/computations to the subgraph. 
+ // + // E.g., + // Compuptation { + // a = 3 + // b = param(0) + // ROOT Tuple(a + b, a + 1, b + 1) + // } + // Calling BuildDynamicInferenceGraph on root will produce the following + // graph: + // + // Compuptation { + // a = False + // b = True + // ROOT Tuple(a | b, a, b) + // } + // + // The result, which is (True, False, True) after evaluation, can be + // interpreted as "First element is dynamic; Second element is static; Third + // element is dynamic". + StatusOr BuildDynamicInferenceGraph(XlaOp root_op); + // Returns the first error that was encountered while building the // computation. When an error is encountered, by default we return a vacuous // XlaOp and inform the user of the error that occurred while @@ -334,12 +375,16 @@ class XlaBuilder { // not available until the computation is built, and eventual error in the // arguments of this API will be detected only at computation Build() time. // - // Note: Aliasing API is 'may-alias' and only donated buffer at runtime will - // be aliased with output. If a buffer is not donated at runtime, a copy will - // be inserted by XLA to prevent buffer clobbering. + // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' + // and only donated buffer at runtime will be aliased with output. If a buffer + // is not donated at runtime, a copy will be inserted by XLA to prevent buffer + // clobbering. void SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - input_output_aliases_.push_back({output_index, param_number, param_index}); + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind kind = + HloInputOutputAliasConfig::AliasKind::kMayAlias) { + input_output_aliases_.push_back( + {output_index, param_number, param_index, kind}); } // Describes an input/output alias as inserted by the SetUpAlias() API. @@ -350,6 +395,8 @@ class XlaBuilder { int64 param_number; // Specifies the index of the aliased buffer in the parameter ShapeIndex param_index; + // Specifies if the alias is a must alias or may alias. + HloInputOutputAliasConfig::AliasKind kind; }; // Looks up the HloInstruction and sets the frontend attribute "attribute" to @@ -624,6 +671,8 @@ class XlaBuilder { XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + virtual StatusOr BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand); XlaOp Transpose(XlaOp operand, absl::Span permutation); virtual StatusOr TransposeInternal( @@ -635,6 +684,10 @@ class XlaBuilder { XlaOp Sort(absl::Span operands, const XlaComputation& comparator, int64 dimension = -1, bool is_stable = false); + virtual StatusOr SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64 dimension, bool is_stable); XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); @@ -651,6 +704,9 @@ class XlaBuilder { XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init); + virtual StatusOr WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, XlaOp init); XlaOp Conditional(XlaOp predicate, XlaOp true_operand, const XlaComputation& true_computation, XlaOp false_operand, @@ -736,14 +792,17 @@ class XlaBuilder { // broadcast_dimensions specifies which dimensions to use for broadcasting // when the operation is between tensors of different ranks. The direction is // only used if opcode is kCompare. 
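  // The optional `type`, when provided, overrides the comparison type that
  // would otherwise be derived from the operand element type (see
  // Comparison::DefaultComparisonType); like `direction`, it is only used when
  // opcode is kCompare. An illustrative client-side call that reaches this
  // helper with an explicit type (operand names are placeholders):
  //
  //   XlaOp lt = Compare(lhs, rhs, /*broadcast_dimensions=*/{},
  //                      ComparisonDirection::kLt,
  //                      Comparison::Type::kFloatTotalOrder);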
- XlaOp BinaryOp( - HloOpcode binop, XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - absl::optional direction = absl::nullopt); + XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + absl::optional direction = absl::nullopt, + absl::optional type = absl::nullopt); // Internal helper method for binary op compare without broadcast dimensions. virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - Comparison::Direction direction); + ComparisonDirection direction); + virtual StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type); // Internal helper method that does the building for an arbitrary binary op // with same ranked operands that doesn't broadcast. @@ -842,6 +901,9 @@ class XlaBuilder { // throughout the TensorFlow op kernel implementations). OpMetadata metadata_; + // A temporary metadata that will only be applied to the next op created. + absl::optional one_shot_metadata_; + // Sharding for this operator. This is structured as a "model"-like operation, // in order to simplify client code, similar to metadata_. absl::optional sharding_; @@ -906,22 +968,13 @@ class XlaBuilder { friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index); - friend XlaOp Eq(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Ne(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Ge(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Gt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Lt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Le(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); friend XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, + Comparison::Type compare_type); friend XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config); friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, @@ -1205,6 +1258,10 @@ class XlaBuilder { TF_RETURN_IF_ERROR(CheckOpBuilder(op)); return LookUpInstructionByHandleInternal(op.handle()); } + + friend XlaOp internal::XlaBuilderBuildFusion( + XlaBuilder* builder, absl::Span operands, + absl::string_view fusion_kind, const XlaComputation& fused_computation); }; // RAII-style object: sets the current sharding assignment in builder on @@ -1511,29 +1568,44 @@ XlaOp GetTupleElement(XlaOp tuple_data, int64 index); // Enqueues an equal-to comparison instruction onto the computation. XlaOp Eq(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a not-equal comparison instruction onto the computation. XlaOp Ne(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a greater-or-equal comparison instruction onto the computation. XlaOp Ge(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a greater-than comparison instruction onto the computation. 
XlaOp Gt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a less-than comparison instruction onto the computation. XlaOp Lt(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a less-or-equal comparison instruction onto the computation. XlaOp Le(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions = {}); +XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); // Enqueues a comparison instruction onto the computation (optionally without // broadcast_dimensions for consistency with others). +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type); XlaOp Compare(XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions, ComparisonDirection direction); diff --git a/tensorflow/compiler/xla/comparison_util.cc b/tensorflow/compiler/xla/comparison_util.cc index 47fb69e3bce..06dd9642cac 100644 --- a/tensorflow/compiler/xla/comparison_util.cc +++ b/tensorflow/compiler/xla/comparison_util.cc @@ -54,32 +54,59 @@ StatusOr StringToComparisonDirection( return it->second; } -Comparison::Comparison(Direction dir, PrimitiveType type) : dir_(dir) { +StatusOr StringToComparisonType( + absl::string_view compare_type_name) { + static auto* type_map = new absl::flat_hash_map({ + {"FLOAT", Comparison::Type::kFloat}, + {"TOTALORDER", Comparison::Type::kFloatTotalOrder}, + {"SIGNED", Comparison::Type::kSigned}, + {"UNSIGNED", Comparison::Type::kUnsigned}, + }); + auto it = type_map->find(compare_type_name); + if (it == type_map->end()) { + return InvalidArgument("Unknown comparison type: %s", compare_type_name); + } + return it->second; +} + +std::string ComparisonTypeToString(Comparison::Type type) { + switch (type) { + case Comparison::Type::kFloat: + return "FLOAT"; + case Comparison::Type::kFloatTotalOrder: + return "TOTALORDER"; + case Comparison::Type::kSigned: + return "SIGNED"; + case Comparison::Type::kUnsigned: + return "UNSIGNED"; + } +} + +Comparison::Comparison(Direction dir, PrimitiveType type) + : dir_(dir), type_(DefaultComparisonType(type)) {} + +Comparison::Type Comparison::DefaultComparisonType(PrimitiveType type) { switch (type) { case S8: case S16: case S32: case S64: - type_ = Type::kSigned; - break; + return Type::kSigned; case PRED: case U8: case U16: case U32: case U64: - type_ = Type::kUnsigned; - break; + return Type::kUnsigned; case F16: case F32: case BF16: case F64: case C64: case C128: - type_ = Type::kFloat; - break; + return Type::kFloat; default: LOG(FATAL) << "Unsupported comparison mode." 
- << ComparisonDirectionToString(dir) << ":" << PrimitiveType_Name(type) << "\n"; } } @@ -164,20 +191,6 @@ bool Comparison::IsAntireflexive() const { } } -/* static */ const char* Comparison::ComparisonTypeToString( - Comparison::Type type) { - switch (type) { - case Type::kFloat: - return "f"; - case Type::kFloatTotalOrder: - return "ft"; - case Type::kSigned: - return "s"; - case Type::kUnsigned: - return "u"; - } -} - std::string Comparison::ToString(std::string prefix1, std::string prefix2) const { return prefix1 + std::string(ComparisonDirectionToString(dir_)) + prefix2 + diff --git a/tensorflow/compiler/xla/comparison_util.h b/tensorflow/compiler/xla/comparison_util.h index 11335c6b5ba..33ae2c67106 100644 --- a/tensorflow/compiler/xla/comparison_util.h +++ b/tensorflow/compiler/xla/comparison_util.h @@ -103,11 +103,11 @@ class Comparison { bool Compare(const T a, const T b) const { return GetComparator()(a, b); } + static Type DefaultComparisonType(PrimitiveType t); private: static Direction Converse(Direction dir); static Direction Inverse(Direction dir); - static const char* ComparisonTypeToString(Type type); const Direction dir_; Type type_; @@ -117,10 +117,14 @@ inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { return os << cmp.ToString(); } string ComparisonDirectionToString(Comparison::Direction direction); +std::string ComparisonTypeToString(Comparison::Type type); StatusOr StringToComparisonDirection( absl::string_view direction_name); +StatusOr StringToComparisonType( + absl::string_view compare_type_name); + using ComparisonDirection = Comparison::Direction; } // namespace xla diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 8ca6e2b294c..2dd7acb2f67 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -71,7 +71,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); opts.set_xla_gpu_deterministic_reductions(false); - opts.set_xla_cpu_enable_xprof_traceme(true); + opts.set_xla_cpu_enable_xprof_traceme(false); opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false); return opts; @@ -535,10 +535,10 @@ static void AllocateFlags() { flag_values->xla_gpu_force_conv_nchw(), "For cuDNN convolutions, always NCHW layouts.")); flag_objects->push_back(tensorflow::Flag( - "xla_gpu_algorithm_blacklist_path", - string_setter_for(&DebugOptions::set_xla_gpu_algorithm_blacklist_path), - flag_values->xla_gpu_algorithm_blacklist_path(), - "An AlgorithmBlacklist text proto file as a blacklist of convolutions to " + "xla_gpu_algorithm_denylist_path", + string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path), + flag_values->xla_gpu_algorithm_denylist_path(), + "An AlgorithmDenylist text proto file as a denylist of convolutions to " "avoid to use.")); flag_objects->push_back(tensorflow::Flag( "xla_gpu_deterministic_reductions", diff --git a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py index 212ad87d94c..16563bab5bc 100644 --- a/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py +++ b/tensorflow/compiler/xla/experimental/xla_sharding/xla_sharding.py @@ -294,3 +294,39 @@ def manual_to_auto_spmd_partition(tensor, manual_sharding, full_shape): """ return tf2xla.spmd_shard_to_full_shape( tensor, manual_sharding=manual_sharding, 
full_shape=full_shape) + + +def mesh_split(tensor, + device_mesh, + tensor_split_dims_mapping, + use_sharding_op=False): + """Returns a tensor that is split along multiple dimensions in a device mesh. + + Args: + tensor: A tf.Tensor to split. + device_mesh: An np.ndarray describing the topology of the device mesh and + each element is the ID of the device in the topology. + tensor_split_dims_mapping: A list of integers that map each tensor axis to + the device mesh axis along which it is sharded. Its length is the tensor + rank, and tensor_split_dims_mapping[i] is device mesh axis for tensor + dimension i. Use -1 for tensor dimensions that are not sharded. + use_sharding_op: If true, adds a sharding op to set the sharding. + + Raises: + ValueError: The number of tensor split dimensions is different from device + mesh rank. + """ + permutation = [d for d in tensor_split_dims_mapping if d >= 0] + if len(permutation) != len(device_mesh.shape): + raise ValueError( + 'Number of tensor split dimensions (%r) is different from device mesh ' + 'rank (%r). tensor_split_dims_mapping: %r, device_mesh.shape: %r' % + (len(permutation), len( + device_mesh.shape), tensor_split_dims_mapping, device_mesh.shape)) + tile_assignment = _np.transpose(device_mesh, permutation) + tile_shape = [ + 1 if d < 0 else device_mesh.shape[d] for d in tensor_split_dims_mapping + ] + tile_assignment = _np.reshape(tile_assignment, tile_shape) + + return tile(tensor, tile_assignment, use_sharding_op=use_sharding_op) diff --git a/tensorflow/compiler/xla/g3doc/_book.yaml b/tensorflow/compiler/xla/g3doc/_book.yaml index e05f69b1e8b..8d217b89ae3 100644 --- a/tensorflow/compiler/xla/g3doc/_book.yaml +++ b/tensorflow/compiler/xla/g3doc/_book.yaml @@ -17,6 +17,8 @@ upper_tabs: path: /xla - title: XLA architecture path: /xla/architecture + - title: Known issues + path: /xla/known_issues - title: Broadcasting semantics path: /xla/broadcasting - title: Develop a new backend for XLA diff --git a/tensorflow/compiler/xla/g3doc/index.md b/tensorflow/compiler/xla/g3doc/index.md index 60bde306266..51d666fba9a 100644 --- a/tensorflow/compiler/xla/g3doc/index.md +++ b/tensorflow/compiler/xla/g3doc/index.md @@ -177,30 +177,6 @@ a bug to a single XLA program by using the [`replay_computation`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/tools/run_hlo_module_main.cc) and iteratively running it on generated programs. -## Known Issues - -Compilation with XLA can greatly improve the performance of your programs, but -the TensorFlow interop has a number of known sharp corners. - -### TensorArray TF/XLA Interconversion - -The problem manifests itself as an error message -`Support for TensorList crossing the XLA/TF boundary is not implemented`. - -XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and -XLA representations is not implemented yet. -This error often arises when the `TensorArray` is used inside the compiled -block, but the derivative is taken outside. - -Workaround: compile the outermost scope which is taking the derivative. - -### Random Number Generation - -XLA currently ignores TF seeds to random operations. This affects stateful TF -random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will -behave as if the compilation was seeded with a new unique seed at each run. This -limitation does not apply to stateless random ops. 
- ## XLA Frontends Apart from TensorFlow, XLA programs can be generated by: diff --git a/tensorflow/compiler/xla/g3doc/known_issues.md b/tensorflow/compiler/xla/g3doc/known_issues.md new file mode 100644 index 00000000000..1c03c716a02 --- /dev/null +++ b/tensorflow/compiler/xla/g3doc/known_issues.md @@ -0,0 +1,32 @@ +# Known Issues + +Compilation with XLA can greatly improve the performance of your programs, but +the TensorFlow interop has a number of known sharp corners. + +## TensorArray TF/XLA interconversion + +The problem manifests itself as an error message +`Support for TensorList crossing the XLA/TF boundary is not implemented`. + +XLA supports `tf.TensorArray`. However, the _interconversion_ between TF and +XLA representations is not implemented yet. +This error often arises when the `TensorArray` is used inside the compiled +block, but the derivative is taken outside. + +Workaround: compile the outermost scope which is taking the derivative. + +## Dynamic `tf.TensorArray` is not supported + +Writes into `tf.TensorArray(..., dynamic_size=True)` are not compilable with +XLA, as such writes require an unknown number of reallocations when the array +exceeds the original bound. + +Workaround: provide a statically known bound to your arrays. + +## Random number generation + +XLA currently ignores TF seeds to random operations. This affects stateful TF +random operations, such as `tf.random.normal`, or `tf.nn.dropout`. XLA will +behave as if the compilation was seeded with a new unique seed at each run. This +limitation does not apply to stateless random ops. + diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 3031bfbf2e2..051c1539f6b 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1235,7 +1235,10 @@ floating-point types. Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge` (greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt` -(less-than). +(less-than). Another set of operators, EqTotalOrder, NeTotalOrder, GeTotalOrder, +GtTotalOrder, LeTotalOrder, and LtTotalOrder, provide the same functionalities, +except that they additionally support a total order over the floating point +numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN. 
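A minimal builder-level sketch, assuming the C++ client API above and using illustrative operand values, that contrasts the default comparison with its total-order variant on a NaN input:

```c++
#include <limits>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/statusor.h"

// Illustrative sketch: emits both a default Lt and an LtTotalOrder compare
// over the same operands so the two semantics can be inspected side by side.
xla::StatusOr<xla::XlaComputation> BuildTotalOrderExample() {
  xla::XlaBuilder b("total_order_example");
  const float nan = std::numeric_limits<float>::quiet_NaN();
  auto x = xla::ConstantR1<float>(&b, {1.0f, nan});
  auto y = xla::ConstantR1<float>(&b, {2.0f, 2.0f});
  // Default partial order: Lt is false whenever either operand is NaN.
  auto lt = xla::Lt(x, y);
  // Total order: +NaN sorts above +Inf, so NaN < 2.0f is still false, while
  // GeTotalOrder(x, y) would report true for the NaN element.
  auto lt_total = xla::LtTotalOrder(x, y);
  xla::Tuple(&b, {lt, lt_total});
  return b.Build();
}
```
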
Arguments | Type | Semantics --------- | ------- | ---------------------------------------- diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 463a8d95fc5..4bec454e520 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -143,7 +143,8 @@ namespace xla { /* static */ bool IndexUtil::IndexInBounds(const Shape& shape, absl::Span index) { int64 rank = shape.rank(); - if (rank != index.size()) { + const int64 index_size = index.size(); + if (rank != index_size) { return false; } for (int64 d = 0; d < rank; ++d) { @@ -157,7 +158,8 @@ namespace xla { /* static */ int IndexUtil::CompareIndices(absl::Span lhs, absl::Span rhs) { int64 rank = lhs.size(); - CHECK_EQ(rhs.size(), rank); + const int64 rhs_rank = rhs.size(); + CHECK_EQ(rhs_rank, rank); for (int64 dim = 0; dim < rank; ++dim) { if (lhs[dim] < rhs[dim]) { return -1; diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index faa33e292c2..afd7141477f 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -342,7 +342,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ std::vector LayoutUtil::MakeLogicalToPhysical( const Layout& layout) { std::vector logical_to_physical(layout.minor_to_major_size()); - for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) { + for (int64 physical = 0, end = logical_to_physical.size(); physical < end; + ++physical) { const int64 logical = Major(layout, physical); logical_to_physical[logical] = physical; } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 73c37d6b2f3..d26e0881c53 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -48,13 +48,17 @@ namespace { using absl::StrCat; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; +// Literals can be used as DMA targets, which can require alignment. We +// force a tensorflow::Allocator::kAllocatorAlignment-byte minimum +// alignment. +constexpr int kMinimumAlignment = 64; // Converts between little and big endian. // // Precondition: size % 2 == 0 (elements in the array are 16 bits long) void ConvertEndianShort(string* bytes) { CHECK_EQ(bytes->size() / 2, 0); - for (int64 i = 0; i < bytes->size(); i += 2) { + for (int64 i = 0, end = bytes->size(); i < end; i += 2) { std::swap((*bytes)[i], (*bytes)[i + 1]); } } @@ -133,12 +137,14 @@ void Literal::SetPiece(const Shape& shape, Piece* piece, bool allocate_arrays) { } } else if (shape.IsArray()) { if (allocate_arrays) { - // Literals can be used as DMA targets, which can require alignment. We - // force a tensorflow::Allocator::kAllocatorAlignment-byte minimum - // alignment. 
- constexpr int kMinimumAlignment = 64; piece->set_buffer(static_cast(tensorflow::port::AlignedMalloc( piece->size_bytes(), kMinimumAlignment))); + if (shape.is_dynamic()) { + CHECK_EQ(piece->dynamic_size_buffer(), nullptr); + piece->set_dynamic_size_buffer( + static_cast(tensorflow::port::AlignedMalloc( + piece->dynamic_size_buffer_bytes(), kMinimumAlignment))); + } } } else { // If the shape is neither an array nor tuple, then it must be @@ -171,6 +177,9 @@ void Literal::DeallocateBuffers() { if (piece->buffer() != nullptr) { tensorflow::port::AlignedFree(piece->buffer()); } + if (piece->dynamic_size_buffer() != nullptr) { + tensorflow::port::AlignedFree(piece->dynamic_size_buffer()); + } }); } @@ -199,6 +208,15 @@ Literal LiteralBase::CreateFromShape(const Shape& shape) { return literal; } +int32 LiteralBase::GetDynamicSize(int64 dim_index) const { + return GetDynamicSize(dim_index, {}); +} + +int32 LiteralBase::GetDynamicSize(int64 dim_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).GetDynamicSize(dim_index); +} + absl::optional LiteralBase::GetFirstInteger() const { switch (shape().element_type()) { case U8: @@ -231,8 +249,10 @@ template Status MutableLiteralBase::CopySliceFromInternal( const LiteralBase& src_literal, absl::Span src_base, absl::Span dest_base, absl::Span copy_size) { - TF_RET_CHECK(src_literal.shape().rank() == src_base.size()); - TF_RET_CHECK(shape().rank() == dest_base.size()); + const int64 src_base_size = src_base.size(); + const int64 dest_base_size = dest_base.size(); + TF_RET_CHECK(src_literal.shape().rank() == src_base_size); + TF_RET_CHECK(shape().rank() == dest_base_size); auto linear_index = [](const Shape& shape, absl::Span multi_index) { @@ -381,7 +401,9 @@ std::vector Literal::DecomposeTuple() { // Move the respective buffer over to the element Literal. dest_piece->set_buffer(src_piece.buffer()); + dest_piece->set_dynamic_size_buffer(src_piece.dynamic_size_buffer()); src_piece.set_buffer(nullptr); + src_piece.set_dynamic_size_buffer(nullptr); }); } // Set this literal to be nil-shaped. @@ -407,23 +429,51 @@ void CopyElementsBetween(absl::Span dest, src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; } while (IndexUtil::BumpIndices(dest_shape, absl::MakeSpan(index))); } - } // namespace -Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { +int32 LiteralBase::Piece::GetDynamicSize(int64 dim_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + if (!subshape_->is_dynamic_dimension(dim_index)) { + // This is a static dimension, return size. + return subshape_->dimensions(dim_index); + } + CHECK_NE(dynamic_size_buffer(), nullptr); + return dynamic_size_buffer_[dim_index]; +} + +void LiteralBase::Piece::SetDynamicSize(int64 dim_index, int32 size) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + CHECK(subshape_->is_dynamic_dimension(dim_index)); + if (dynamic_size_buffer() == nullptr) { + // Lazily initialize the dynamic size buffer. + set_dynamic_size_buffer(static_cast(tensorflow::port::AlignedMalloc( + dynamic_size_buffer_bytes(), kMinimumAlignment))); + /*for (int64 i = 0; i < subshape().rank(); ++i) { + // Initialized to -1 to help debug. 
+ dynamic_size_buffer_[i] = -1; + }*/ + } + dynamic_size_buffer_[dim_index] = size; +} + +Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src, + bool only_dynamic_bound) { CHECK(subshape_ != nullptr); CHECK(src.subshape_ != nullptr); if (ShapeUtil::Equal(subshape(), src.subshape())) { // If the layouts are equal it's faster just to memcpy. memcpy(buffer(), src.buffer(), src.size_bytes()); } else { - TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); std::vector origin(subshape().rank(), 0); switch (subshape().element_type()) { -#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ - case (XLA_T): \ - CopyElementsBetween(data(), src.data(), \ - subshape(), src.subshape()); \ +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + if (only_dynamic_bound) { \ + CopyElementsWithDynamicBound(src); \ + } else { \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + } \ break; COPY_ELEMENTS(U8, uint8); COPY_ELEMENTS(U16, uint16); @@ -447,21 +497,54 @@ Status LiteralBase::Piece::CopyFrom(const LiteralBase::Piece& src) { PrimitiveType_Name(subshape().element_type())); } } + DCHECK_EQ(dynamic_size_buffer_bytes(), src.dynamic_size_buffer_bytes()); + if (subshape().is_dynamic() && src.subshape().is_dynamic()) { + CHECK_NE(dynamic_size_buffer_, nullptr); + CHECK_NE(src.dynamic_size_buffer_, nullptr); + memcpy(dynamic_size_buffer(), src.dynamic_size_buffer(), + src.dynamic_size_buffer_bytes()); + } return Status::OK(); } +void MutableLiteralBase::SetDynamicSize(int64 dim_index, int32 size) { + return SetDynamicSize(dim_index, {}, size); +} + +void MutableLiteralBase::SetDynamicSize(int64 dim_index, + const ShapeIndex& shape_index, + int32 size) { + Shape* subshape_ = ShapeUtil::GetMutableSubshape(shape_.get(), shape_index); + CHECK_GE(subshape_->dimensions(dim_index), size); + if (subshape_->dimensions(dim_index) == size) { + subshape_->set_dynamic_dimension(dim_index, false); + return; + } + subshape_->set_dynamic_dimension(dim_index, true); + piece(shape_index).SetDynamicSize(dim_index, size); +} + Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index, - const ShapeIndex& src_shape_index) { + const ShapeIndex& src_shape_index, + bool only_dynamic_bound) { const Shape& dest_subshape = ShapeUtil::GetSubshape(shape(), dest_shape_index); const Shape& src_subshape = ShapeUtil::GetSubshape(src_literal.shape(), src_shape_index); - if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { - return InvalidArgument( - "Destination subshape incompatible with source subshape: %s vs %s", - ShapeUtil::HumanString(dest_subshape), - ShapeUtil::HumanString(src_subshape)); + if (only_dynamic_bound) { + auto bound_shape = dest_subshape.is_static() ? src_subshape : dest_subshape; + auto compact_shape = + dest_subshape.is_static() ? dest_subshape : src_subshape; + CHECK(ShapeUtil::DynamicShapeIsCompatible(compact_shape, bound_shape)) + << compact_shape.ToString() << " vs " << bound_shape.ToString(); + } else { + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape), + ShapeUtil::HumanString(src_subshape)); + } } return root_piece_->ForEachMutableSubpieceWithStatus( [&](const ShapeIndex& index, Piece* piece) { @@ -483,10 +566,13 @@ Status MutableLiteralBase::CopyFrom(const LiteralSlice& src_literal, } // Construct the index of the corresponding piece in the source literal. 
ShapeIndex src_piece_index = src_shape_index; - for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + for (int64 i = dest_shape_index.size(), end = index.size(); i < end; + ++i) { src_piece_index.push_back(index[i]); } - TF_RETURN_IF_ERROR(piece->CopyFrom(src_literal.piece(src_piece_index))); + TF_RETURN_IF_ERROR( + piece->CopyFrom(src_literal.piece(src_piece_index), + /*only_dynamic_bound=*/only_dynamic_bound)); return Status::OK(); }); } @@ -514,7 +600,9 @@ Status Literal::MoveFrom(Literal&& src_literal, } Piece& dest_piece = piece(dest_index); tensorflow::port::AlignedFree(dest_piece.buffer()); + tensorflow::port::AlignedFree(dest_piece.dynamic_size_buffer()); dest_piece.set_buffer(src_piece.buffer()); + dest_piece.set_dynamic_size_buffer(src_piece.dynamic_size_buffer()); }); src_literal.shape_ = absl::make_unique(ShapeUtil::MakeNil()); @@ -629,13 +717,48 @@ Literal LiteralBase::Relayout(const Shape& shape_with_layout) const { return result; } +Literal LiteralBase::ToBoundedDynamic(const Shape& bounded_shape) const { + CHECK(bounded_shape.is_dynamic()); + Literal result(bounded_shape); + ShapeUtil::ForEachSubshape( + shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (!subshape.IsArray()) { + return; + } + for (int64 i = 0; i < subshape.rank(); ++i) { + result.SetDynamicSize(i, subshape.dimensions(i)); + } + }); + TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true)); + + return result; +} + +Literal LiteralBase::ToStatic() const { + // Create new shape with 'new_layout' set at the given shape index. + Shape new_shape = shape(); + ShapeUtil::ForEachMutableSubshape( + &new_shape, [this](Shape* subshape, const ShapeIndex& index) { + if (!subshape->IsArray()) { + return; + } + for (int64 i = 0; i < subshape->rank(); ++i) { + subshape->set_dynamic_dimension(i, false); + subshape->set_dimensions(i, GetDynamicSize(i, index)); + } + }); + Literal result(new_shape); + TF_CHECK_OK(result.CopyFrom(*this, {}, {}, /*only_dynamic_bound=*/true)); + return result; +} + StatusOr LiteralBase::Broadcast( const Shape& result_shape, absl::Span dimensions) const { if (!shape().IsArray()) { return InvalidArgument("Broadcast only supports arrays."); } - for (int64 i = 0; i < dimensions.size(); i++) { + for (int64 i = 0, end = dimensions.size(); i < end; i++) { TF_RET_CHECK(shape().dimensions(i) == result_shape.dimensions(dimensions[i])); } @@ -652,9 +775,14 @@ StatusOr LiteralBase::Broadcast( const int64 primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + for (int64 i = 0; i < dimensions.size(); ++i) { + int64 dynamic_size = GetDynamicSize(i); + result.SetDynamicSize(dimensions[i], dynamic_size); + } + ShapeUtil::ForEachIndex( result_shape, [&](absl::Span output_index) { - for (int64 i = 0; i < dimensions.size(); ++i) { + for (int64 i = 0, end = dimensions.size(); i < end; ++i) { scratch_source_index[i] = output_index[dimensions[i]]; } int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( @@ -674,6 +802,9 @@ StatusOr LiteralBase::Reshape( if (!shape().IsArray()) { return InvalidArgument("Reshape does not support tuples."); } + if (shape().is_dynamic()) { + return Unimplemented("Dynamic reshape is not implemented."); + } Literal output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { output = Relayout(LayoutUtil::GetDefaultLayoutForRank(shape().rank())); @@ -728,6 +859,9 @@ Literal LiteralBase::Transpose(absl::Span permutation) const { layout->add_minor_to_major(inverse_permutation[index]); } Literal 
new_literal(permuted_shape); + for (int64 i = 0; i < shape().rank(); i++) { + new_literal.SetDynamicSize(inverse_permutation[i], GetDynamicSize(i)); + } DCHECK_EQ(ShapeUtil::ByteSizeOf(new_literal.shape()), ShapeUtil::ByteSizeOf(shape())); std::memcpy(new_literal.untyped_data(), untyped_data(), size_bytes()); @@ -747,6 +881,14 @@ Literal LiteralBase::SliceInternal( return Get(new_indices); }) .ok()); + for (int64 dnum = 0; dnum < shape().rank(); ++dnum) { + if (shape().is_dynamic_dimension(dnum)) { + int64 dynamic_size = GetDynamicSize(dnum) - start_indices[dnum]; + CHECK_GE(dynamic_size, 0) << GetDynamicSize(dnum); + dynamic_size = std::min(dynamic_size, result_shape.dimensions(dnum)); + result_literal.SetDynamicSize(dnum, dynamic_size); + } + } return result_literal; } @@ -763,9 +905,10 @@ Literal LiteralBase::Slice(absl::Span start_indices, CHECK_GE(dimension, 0) << "dnum = " << dnum; result_dimensions.push_back(dimension); } - const auto result_shape = + auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); + ShapeUtil::CopyDynamicDimensions(&result_shape, shape()); switch (result_shape.element_type()) { case PRED: return SliceInternal(result_shape, start_indices); @@ -861,14 +1004,20 @@ absl::optional LiteralBase::GetIntegralAsS64( switch (shape().element_type()) { case PRED: return Get(multi_index); + case S8: + return Get(multi_index); case U8: return Get(multi_index); + case S16: + return Get(multi_index); + case U16: + return Get(multi_index); case S32: return Get(multi_index); - case S64: - return Get(multi_index); case U32: return Get(multi_index); + case S64: + return Get(multi_index); case U64: return Get(multi_index); default: @@ -1045,8 +1194,9 @@ void DenseArrayToStringHelper(const LiteralBase& literal, } // Handle the non-innermost tensors of a 2D+ tensor. 
if (brace == "{") { + const int64 accum_indices_size = accum_indices->size(); if (rank > 3 && !accum_indices->empty() && - accum_indices->size() < rank) { + accum_indices_size < rank) { int index = accum_indices->size() - 1; int value = accum_indices->back(); return StrCat(brace, " /*i", index, "=", value, "*/\n"); @@ -1082,11 +1232,24 @@ void DenseArrayToStringHelper(const LiteralBase& literal, if (print_shape) { pieces->push_back(ShapeToString(print_layout, subshape)); + if (subshape.is_dynamic()) { + pieces->push_back("("); + for (int64 i = 0; i < subshape.dimensions_size(); ++i) { + pieces->push_back(StrCat(literal.GetDynamicSize(i, shape_index))); + if (i < subshape.dimensions_size() - 1) { + pieces->push_back(","); + } + } + pieces->push_back(")"); + } pieces->push_back(" "); } std::vector indices = {}; - std::vector dimensions(subshape.dimensions().begin(), - subshape.dimensions().end()); + std::vector dimensions; + dimensions.reserve(subshape.rank()); + for (int64 i = 0; i < subshape.rank(); ++i) { + dimensions.push_back(literal.GetDynamicSize(i, shape_index)); + } to_string_recursive(dimensions, &indices); } @@ -1367,20 +1530,51 @@ StatusOr LiteralBase::ConvertToShape(const Shape& dest_shape) const { } Literal literal(ShapeUtil::MakeTupleShape(element_shapes), /*allocate_arrays=*/false); - for (int i = 0; i < elements.size(); ++i) { + for (int i = 0, end = elements.size(); i < end; ++i) { TF_CHECK_OK( literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } return literal; } +template +void LiteralBase::Piece::CopyElementsWithDynamicBound( + const LiteralBase::Piece& src) { + auto dest_shape = subshape(); + auto src_shape = src.subshape(); + + // At least one shape has to be static as bound. + CHECK(dest_shape.is_static() || src_shape.is_static()); + auto bound_shape = dest_shape.is_static() ? src_shape : dest_shape; + if (ShapeUtil::IsZeroElementArray(dest_shape)) { + return; + } + std::vector index(dest_shape.rank()); + do { + bool out_of_bound = false; + for (int64 i = 0; i < index.size(); ++i) { + // Do not copy elements beyond dynamic bound. 
+ if (index[i] >= GetDynamicSize(i) || index[i] >= src.GetDynamicSize(i)) { + out_of_bound = true; + } + } + if (out_of_bound) { + continue; + } + data()[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, + index)] = + src.data()[IndexUtil::MultidimensionalIndexToLinearIndex( + src_shape, index)]; + } while (IndexUtil::BumpIndices(bound_shape, absl::MakeSpan(index))); +} + template bool LiteralBase::Piece::EqualElementsInternal( const LiteralBase::Piece& other, std::vector* multi_index) const { if (multi_index->size() == subshape().rank()) { return (Get(*multi_index) == other.Get(*multi_index)); } - for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + for (int64 i = 0; i < GetDynamicSize(multi_index->size()); ++i) { multi_index->push_back(i); if (!EqualElementsInternal(other, multi_index)) { return false; @@ -1390,10 +1584,24 @@ bool LiteralBase::Piece::EqualElementsInternal( return true; } -bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { +bool LiteralBase::Piece::EqualDynamicSize( + const LiteralBase::Piece& other) const { DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + if (subshape().is_static()) { + return true; + } - if (ShapeUtil::Equal(subshape(), other.subshape()) && + for (int64 i = 0; i < subshape().rank(); ++i) { + if (GetDynamicSize(i) != other.GetDynamicSize(i)) { + return false; + } + } + return true; +} + +bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { + if (subshape().is_static() && + ShapeUtil::Equal(subshape(), other.subshape()) && LayoutUtil::IsDenseArray(subshape())) { CHECK_EQ(size_bytes(), other.size_bytes()); return memcmp(buffer(), other.buffer(), size_bytes()) == 0; @@ -1403,14 +1611,16 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { switch (subshape().element_type()) { case PRED: return EqualElementsInternal(other, &multi_index); - case U8: - return EqualElementsInternal(other, &multi_index); + case S8: + return EqualElementsInternal(other, &multi_index); case S16: return EqualElementsInternal(other, &multi_index); case S32: return EqualElementsInternal(other, &multi_index); case S64: return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); case U16: return EqualElementsInternal(other, &multi_index); case U32: @@ -1436,17 +1646,33 @@ bool LiteralBase::Piece::EqualElements(const LiteralBase::Piece& other) const { } bool LiteralBase::operator==(const LiteralBase& other) const { - if (!ShapeUtil::Compatible(shape(), other.shape())) { + // Checking the structure of tuple literals. Checks for dense arrays are + // performed below. 
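  // Note: two array literals with the same bounded shape but different dynamic
  // sizes compare unequal here, even when the elements inside the common
  // dynamic bound are identical.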
+ if (!ShapeUtil::EqualStructure(shape(), other.shape())) { return false; } return root_piece().ForEachSubpieceWithBool( [&](const ShapeIndex& index, const Piece& piece) { + const Piece& other_piece = other.piece(index); + const Shape& subshape = piece.subshape(); + const Shape& other_subshape = other_piece.subshape(); + if (subshape.element_type() != other_subshape.element_type()) { + return false; + } if (!piece.subshape().IsArray()) { return true; } + if (subshape.rank() != other_subshape.rank()) { + return false; + } + + for (int64 i = 0; i < subshape.rank(); ++i) { + if (piece.GetDynamicSize(i) != other_piece.GetDynamicSize(i)) { + return false; + } + } - const Piece& other_piece = other.piece(index); if (!piece.EqualElements(other_piece)) { return false; } @@ -1677,13 +1903,13 @@ bool LiteralBase::IsR1Iota() const { auto is_iota_at_idx = [&](const int64 idx) { switch (shape().element_type()) { case U8: - return Get({idx}) == idx; + return static_cast(Get({idx})) == idx; case U16: - return Get({idx}) == idx; + return static_cast(Get({idx})) == idx; case U32: - return Get({idx}) == idx; + return static_cast(Get({idx})) == idx; case U64: - return Get({idx}) == idx; + return static_cast(Get({idx})) == idx; case S8: return Get({idx}) == idx; case S16: @@ -1960,8 +2186,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } case C128: { auto complex_data = data(); - TF_RET_CHECK(proto.c128s_size() == complex_data.size() * 2); - for (int64 i = 0; i < complex_data.size(); ++i) { + const int64 complex_data_size_doubled = complex_data.size() * 2; + TF_RET_CHECK(proto.c128s_size() == complex_data_size_doubled); + for (int64 i = 0, end = complex_data.size(); i < end; ++i) { complex_data[i] = complex128{proto.c128s(i * 2), proto.c128s(i * 2 + 1)}; } @@ -2035,6 +2262,7 @@ void MutableBorrowingLiteral::CopyPieceSubtree(const Shape& shape, } } else if (shape.IsArray()) { dest_piece->set_buffer(src_piece->buffer()); + dest_piece->set_dynamic_size_buffer(src_piece->dynamic_size_buffer()); } else { // If the shape is neither an array nor tuple, then it must be // zero-sized. Otherwise, some memory needs to be allocated for it. @@ -2179,7 +2407,7 @@ BorrowingLiteral::BorrowingLiteral(absl::Span src_buf_ptrs, root_piece_.set_subshape(shape_.get()); BuildPieceSubtree(*shape_, &root_piece_); - for (int i = 0; i < src_buf_ptrs.size(); ++i) { + for (int i = 0, end = src_buf_ptrs.size(); i < end; ++i) { const auto& src_shape = shape_->tuple_shapes(i); CHECK(src_shape.IsArray()); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index a2be92fbf5b..1ee71618887 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -112,6 +112,10 @@ class LiteralBase { template NativeT Get(absl::Span multi_index) const; + // Get the dynamic size on dim_index in the literal at the given shape_index. + int32 GetDynamicSize(int64 dim_index, const ShapeIndex& shape_index) const; + int32 GetDynamicSize(int64 dim_index) const; + // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. template @@ -281,6 +285,18 @@ class LiteralBase { // than being limited to a single array within the shape. Literal Relayout(const Shape& shape_with_layout) const; + // Generate a new literal whose static sizes are equal to the previous + // literal's dynamic sizes. 
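  //
  // For example (sizes are illustrative): an f32[2,<=3] literal whose dynamic
  // size on dimension 1 is 1 is returned as an f32[2,1] literal holding the
  // same elements.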
+ Literal ToStatic() const; + + // Expand a static literal into a new one with a bounded dyanmic literal. The + // static dimensions of the original literal becomes dynamic dimensions of the + // new literal, where the argument `bounded_shape` becomes the bounded shape + // of the new literal. + // + // Precondition: bounded_shape.is_dynamic() + Literal ToBoundedDynamic(const Shape& bounded_shape) const; + // Creates a new literal by reshaping this literal to have the given // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. @@ -354,10 +370,22 @@ class LiteralBase { template void Set(absl::Span index, NativeT value); + int32 GetDynamicSize(int64 dim_index) const; + void SetDynamicSize(int64 dim_index, int32 size); // Gets/sets the buffer holding the array data. char* buffer() const { return buffer_; } void set_buffer(char* buffer) { buffer_ = buffer; } + // Gets/sets the buffer holding dynamic sizes. + int32* dynamic_size_buffer() const { return dynamic_size_buffer_; } + void set_dynamic_size_buffer(int32* dynamic_size_buffer) { + dynamic_size_buffer_ = dynamic_size_buffer; + } + + int64 dynamic_size_buffer_bytes() const { + return subshape().dimensions_size() * sizeof(int32); + } + // Gets or sets the subshape of this piece. This reference points to a // subshape within the shape in the containing Literal (Literal::shape_). const Shape& subshape() const { return *subshape_; } @@ -434,15 +462,21 @@ class LiteralBase { } // Returns true if this piece and 'other' contain the same data. This piece - // and 'other' must be array-shaped and compatible. + // and 'other' must be array-shaped and compatible. If a literal has dynamic + // shape, comparison is done only for the valid elements. bool EqualElements(const Piece& other) const; + // Returns true if this piece and other pieces have the same dynamic + // dimension sizes. + bool EqualDynamicSize(const Piece& other) const; + // Writes the shape and data (if array-shaped) into the given proto. void WriteToProto(LiteralProto* proto) const; // Copy the data from 'src' into this piece's buffer. Shapes of this piece - // and src must be compatible. - Status CopyFrom(const Piece& src); + // and src must be compatible. If only_dynamic_bound is true, only elements + // within dynamic bounds will be copied. + Status CopyFrom(const Piece& src, bool only_dynamic_bound); // Copies the data from the given proto into this piece. The shape of this // piece must be equal (not just compatible) to the shape of the proto. @@ -497,9 +531,15 @@ class LiteralBase { bool EqualElementsInternal(const Piece& other, std::vector* multi_index) const; + // Internal helper to copy elements from another given piece + template + void CopyElementsWithDynamicBound(const LiteralBase::Piece& src); + // For array-shaped pieces, this is the buffer holding the literal data. char* buffer_ = nullptr; + int32* dynamic_size_buffer_ = nullptr; + // The shape of piece. This points into the shape of the containing Literal // (Literal::shape_). const Shape* subshape_ = nullptr; @@ -550,6 +590,11 @@ class MutableLiteralBase : public LiteralBase { // mutate the shape as this can produce malformed Literals. Shape* mutable_shape_do_not_use() { return shape_.get(); } + // Set the dynamic size on dim_index in the literal at the given shape_index. 
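  //
  // A minimal usage sketch (values are illustrative):
  //
  //   Literal l = LiteralUtil::CreateR1<int32>({1, 2, 3});  // s32[3]
  //   l.SetDynamicSize(/*dim_index=*/0, /*size=*/2);        // now s32[<=3](2)
  //   CHECK_EQ(l.GetDynamicSize(0), 2);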
+ void SetDynamicSize(int64 dim_index, const ShapeIndex& shape_index, + int32 size); + void SetDynamicSize(int64 dim_index, int32 size); + // Returns a pointer to the underlying buffer holding the array at the given // shape index. CHECKs if the subshape of the literal at the given ShapeIndex // is not array. @@ -560,10 +605,12 @@ class MutableLiteralBase : public LiteralBase { // Copy values from 'src_literal' rooted at 'src_shape_index' into this // literal rooted at 'dest_shape_index'. The subshape of this literal rooted // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' - // rooted at 'src_shape_index', but need not be arrays. + // rooted at 'src_shape_index', but need not be arrays. If only_dynamic_bound + // is true, only elements within dynamic bounds will be copied. Status CopyFrom(const LiteralSlice& src_literal, const ShapeIndex& dest_shape_index = {}, - const ShapeIndex& src_shape_index = {}); + const ShapeIndex& src_shape_index = {}, + bool only_dynamic_bound = false); // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each @@ -924,9 +971,14 @@ void LiteralBase::EachCell( return; } std::vector indices(shape().rank(), 0); + + Shape shape_dynamic = shape(); + for (int64 i = 0; i < shape_dynamic.rank(); ++i) { + shape_dynamic.set_dimensions(i, GetDynamicSize(i)); + } do { per_cell(indices, Get(indices)); - } while (IndexUtil::BumpIndices(shape(), absl::MakeSpan(indices))); + } while (IndexUtil::BumpIndices(shape_dynamic, absl::MakeSpan(indices))); } template diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index e1f52f72e5d..155d281df0c 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -218,23 +218,12 @@ int64 RecursiveElementCount(const Shape& shape) { // Returns whether the given value is infinity. template bool IsInf(NativeT val) { - return std::isinf(val); + return Eigen::numext::isinf(val); } - -template <> -bool IsInf(half val) { - return std::isinf(static_cast(val)); -} - // Returns whether the given value is nan. template -float IsNan(NativeT value) { - return std::isnan(value); -} - -template <> -float IsNan(half value) { - return IsNan(static_cast(value)); +bool IsNan(NativeT value) { + return Eigen::numext::isnan(value); } // Converts the given floating-point value to a string. diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 37316a2a807..a58e450a55a 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -149,6 +149,16 @@ TEST_F(LiteralUtilTest, R2ToString) { EXPECT_EQ(expected, literal.ToString()); } +TEST_F(LiteralUtilTest, R2DynamicToString) { + auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + literal.SetDynamicSize(0, {}, 2); + const string expected = R"(s32[<=3,2](2,2) { + { 1, 2 }, + { 3, 4 } +})"; + EXPECT_EQ(expected, literal.ToString()); +} + TEST_F(LiteralUtilTest, R3ToString) { const auto literal = LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); @@ -421,6 +431,28 @@ TEST_F(LiteralUtilTest, TupleEquality) { EXPECT_NE(tuple1, different_tuple); } +TEST_F(LiteralUtilTest, DynamicShapeEquality) { + // Test equality with tuples. 
+ auto r1 = LiteralUtil::CreateR1({1.0, 2.0}); + r1.SetDynamicSize(0, {}, 1); + auto r2 = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + r2.SetDynamicSize(0, {}, 1); + auto tuple1 = LiteralUtil::MakeTuple({&r1, &r2}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto r1_clone = LiteralUtil::CreateR1({1.0, 3.0}); + r1_clone.SetDynamicSize(0, {}, 1); + auto tuple2 = LiteralUtil::MakeTuple({&r1_clone, &r2}); + EXPECT_EQ(tuple1, tuple2); + + // Tuple with different dynamic sizes. + auto r2_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + r2_clone.SetDynamicSize(0, {}, 2); + auto tuple_3 = LiteralUtil::MakeTuple({&r1_clone, &r2_clone}); + EXPECT_NE(tuple1, tuple_3); +} + TEST_F(LiteralUtilTest, C64Equality) { // Test equality with tuples. auto vector = LiteralUtil::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); @@ -692,6 +724,47 @@ TEST_F(LiteralUtilTest, TransposeR4) { }); } +TEST_F(LiteralUtilTest, TransposeDynamicR2) { + // F32[2, <=3] (2, 1) + auto original = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}); + original.SetDynamicSize(1, 1); + // F32[<=3, 2] (1, 2) + auto reshape = original.Transpose(/*permutation=*/{1, 0}); + + reshape.EachCell([&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get({indices[1], indices[0]})); + }); +} + +TEST_F(LiteralUtilTest, ToStaticR2) { + // F32[2, <=3] (2, 1) + auto original = LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}); + original.SetDynamicSize(1, 1); + // F32[2, 1] + auto static_literal = original.ToStatic(); + EXPECT_EQ(static_literal.shape(), ShapeUtil::MakeShape(F32, {2, 1})); + EXPECT_TRUE(static_literal.shape().is_static()); + + static_literal.EachCell( + [&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get({indices[0], indices[1]})); + }); +} + +TEST_F(LiteralUtilTest, ToBoundedDynamicR2) { + // F32[2, 1] + auto original = LiteralUtil::CreateR2({{1}, {4}}); + // F32[2, <=3] (2, 1) + auto dynamic_shape = ShapeUtil::MakeShape(F32, {2, 3}, {false, true}); + auto dynamic_literal = original.ToBoundedDynamic(dynamic_shape); + EXPECT_EQ(dynamic_literal.shape(), dynamic_shape); + + dynamic_literal.EachCell( + [&](absl::Span indices, float value) { + EXPECT_EQ(value, original.Get({indices[0], indices[1]})); + }); +} + TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. @@ -797,6 +870,38 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { EXPECT_EQ(input_2x3x2, result); } +TEST_F(LiteralUtilTest, SliceR2Dynamic) { + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + input_3x4.SetDynamicSize(1, 3); + // slice second dim from dynamic size 3 to dynamic size 1. 
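  // Starting at column 1 leaves two of the three dynamically valid columns,
  // and the width-1 slice window clamps that to one, so the result reports
  // dynamic size 1 on dimension 1.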
+ auto result = input_3x4.Slice({0, 1}, {2, 2}); + auto expected = LiteralUtil::CreateR2({{2}, {6}}); + EXPECT_EQ(expected, result); + EXPECT_EQ(result.GetDynamicSize(1), 1); +} + +TEST_F(LiteralUtilTest, SliceR2DynamicInBound) { + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + input_3x4.SetDynamicSize(1, 1); + auto result = input_3x4.Slice({0, 0}, {2, 2}); + auto expected = LiteralUtil::CreateR2({{1}, {5}}); + EXPECT_EQ(expected, result); + EXPECT_EQ(result.GetDynamicSize(1), 1); +} + +TEST_F(LiteralUtilTest, SliceR2DynamicOutOfBound) { + auto input_3x4 = LiteralUtil::CreateR2( + {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + input_3x4.SetDynamicSize(1, 1); + auto result = input_3x4.Slice({0, 1}, {2, 3}); + auto expected = LiteralUtil::CreateR2({{}, {}}); + EXPECT_EQ(expected, result); + // Out of bound access clamps into 0 sized dimension. + EXPECT_EQ(result.GetDynamicSize(1), 0); +} + TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); @@ -1510,7 +1615,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) { EXPECT_EQ(u1, r[3]); } -TEST_F(LiteralUtilTest, LiteralSliceTest) { +TEST_F(LiteralUtilTest, LiteralDynamicSliceTest) { auto scalar = LiteralUtil::CreateR0(1.0); auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = LiteralUtil::MakeTuple({&scalar, &matrix}); @@ -1973,6 +2078,17 @@ TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { LiteralUtil::CreateR2({{9, 9}, {9, 9}})); } +TEST_F(LiteralUtilTest, DynamicBroadcast) { + Literal literal = LiteralUtil::CreateR1({1, 2}); + literal.SetDynamicSize(0, 1); + TF_ASSERT_OK_AND_ASSIGN( + Literal broadcasted_literal, + literal.Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(broadcasted_literal, LiteralUtil::CreateR2({{1}, {1}})); + EXPECT_EQ(broadcasted_literal.GetDynamicSize(1), 1); +} + TEST_F(LiteralUtilTest, GetAsComplex128) { complex128 value = {1, 0}; Literal c1 = LiteralUtil::CreateR0(value); diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 4304c207cad..0286aa20b3b 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -67,7 +67,7 @@ Literal ConvertType(LiteralSlice literal) { primitive_util::NativeToPrimitiveType()) { auto src = literal.data(shape_index); auto dest = result.data(shape_index); - for (int64 i = 0; i < src.size(); ++i) { + for (int64 i = 0, end = src.size(); i < end; ++i) { dest[i] = static_cast(src[i]); } } else { @@ -329,7 +329,7 @@ Literal ConvertType(LiteralSlice literal) { /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) { Literal literal(ShapeUtil::MakeShape(U8, {static_cast(value.size())})); - for (int i = 0; i < value.size(); ++i) { + for (int i = 0, end = value.size(); i < end; ++i) { literal.Set({i}, value[i]); } return literal; @@ -345,7 +345,7 @@ Literal ConvertType(LiteralSlice literal) { absl::Span new_dimensions, absl::Span minor_to_major, const LiteralSlice& literal) { int64 new_num_elements = 1; - for (int64 i = 0; i < new_dimensions.size(); ++i) { + for (int64 i = 0, end = new_dimensions.size(); i < end; ++i) { new_num_elements *= new_dimensions[i]; } CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); @@ -472,7 +472,7 @@ Literal ConvertType(LiteralSlice literal) { element_shapes.push_back(element->shape()); } Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); - for (int i = 0; i < 
elements.size(); ++i) { + for (int i = 0, end = elements.size(); i < end; ++i) { TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } return literal; @@ -485,7 +485,7 @@ Literal ConvertType(LiteralSlice literal) { element_shapes.push_back(element.shape()); } Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); - for (int i = 0; i < elements.size(); ++i) { + for (int i = 0, end = elements.size(); i < end; ++i) { TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i})); } return literal; @@ -499,7 +499,7 @@ Literal ConvertType(LiteralSlice literal) { element_shapes.push_back(element.shape()); } Literal literal(ShapeUtil::MakeTupleShape(element_shapes)); - for (int64 i = 0; i < elements.size(); ++i) { + for (int64 i = 0, end = elements.size(); i < end; ++i) { TF_CHECK_OK( literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); } diff --git a/tensorflow/compiler/xla/metric_table_report.cc b/tensorflow/compiler/xla/metric_table_report.cc index bad65ac3201..1a616341315 100644 --- a/tensorflow/compiler/xla/metric_table_report.cc +++ b/tensorflow/compiler/xla/metric_table_report.cc @@ -80,9 +80,11 @@ void MetricTableReport::WriteReportToInfoLog(double expected_metric_sum) { int64 pos = 0; const string report = MakeReport(expected_metric_sum); - while (pos < report.size()) { + const int report_size = report.size(); + while (pos < report_size) { int64 end_of_line = report.find('\n', pos); - if (end_of_line == string::npos) { + const int64 _npos = string::npos; + if (end_of_line == _npos) { end_of_line = report.size(); } absl::string_view line(report.data() + pos, end_of_line - pos); @@ -161,7 +163,8 @@ void MetricTableReport::AppendCategoryTable() { const char* const kIndentPrefix = " * "; int64 entries_to_show = std::min(max_entries_per_category_to_show_, category.entries.size()); - if (category.entries.size() == entries_to_show + 1) { + const int64 category_entries_size = category.entries.size(); + if (category_entries_size == entries_to_show + 1) { // May as well show the last entry on the line that would otherwise say // that there is a single entry not shown. ++entries_to_show; @@ -224,7 +227,8 @@ void MetricTableReport::AppendTableRow(const string& text, const double metric, // Don't try to make a gigantic string and crash if expected_metric_sum_ is // wrong somehow. int64 padding_len = 1; - if (max_metric_string_size >= metric_string.size()) { + const int64 metric_string_size = metric_string.size(); + if (max_metric_string_size >= metric_string_size) { padding_len += max_metric_string_size - metric_string.size(); } string padding(padding_len, ' '); @@ -254,7 +258,7 @@ string MetricTableReport::MetricString(double metric) { sp1.remove_prefix(1); } // Copy rest of input characters. 
- for (int64 i = 0; i < sp1.size(); ++i) { + for (int64 i = 0, end = sp1.size(); i < end; ++i) { if (i > 0 && (sp1.size() - i) % 3 == 0) { output.push_back(','); } diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index 6e61e0600a0..5b3b75eb352 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -59,6 +59,10 @@ cc_library( name = "tracked_device_buffer", srcs = ["tracked_device_buffer.cc"], hdrs = ["tracked_device_buffer.h"], + visibility = [ + "//learning/pathways/data_parallel:__pkg__", + "//tensorflow:internal", + ], deps = [ ":event_pool", ":local_device_state", diff --git a/tensorflow/compiler/xla/pjrt/distributed/client.cc b/tensorflow/compiler/xla/pjrt/distributed/client.cc index 830e512b156..43c0c7b277d 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/client.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/client.cc @@ -17,14 +17,16 @@ limitations under the License. #include // NOLINT +#include "absl/time/time.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" #include "tensorflow/compiler/xla/pjrt/distributed/util.h" namespace xla { DistributedRuntimeClient::DistributedRuntimeClient( - std::shared_ptr<::grpc::Channel> channel) - : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))) {} + std::shared_ptr<::grpc::Channel> channel, absl::Duration rpc_timeout) + : stub_(grpc::DistributedRuntimeService::NewStub(std::move(channel))), + rpc_timeout_(rpc_timeout) {} DistributedRuntimeClient::~DistributedRuntimeClient() = default; xla::Status DistributedRuntimeClient::Connect( @@ -35,6 +37,7 @@ xla::Status DistributedRuntimeClient::Connect( ctx.set_deadline(absl::ToChronoTime(absl::Now() + rpc_timeout_)); ConnectRequest request; request.set_protocol_version(kDistributedRuntimeProtocolVersion); + request.set_timeout_milliseconds(absl::ToInt64Milliseconds(rpc_timeout_)); *request.mutable_local_topology() = local_topology; VLOG(10) << "Connect: " << request.DebugString(); ConnectResponse response; diff --git a/tensorflow/compiler/xla/pjrt/distributed/client.h b/tensorflow/compiler/xla/pjrt/distributed/client.h index 865a752849e..049d76af4d6 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/client.h +++ b/tensorflow/compiler/xla/pjrt/distributed/client.h @@ -29,7 +29,10 @@ namespace xla { class DistributedRuntimeClient { public: - explicit DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel); + DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel, + absl::Duration rpc_timeout); + explicit DistributedRuntimeClient(std::shared_ptr<::grpc::Channel> channel) + : DistributedRuntimeClient(channel, absl::Seconds(120)) {} ~DistributedRuntimeClient(); xla::Status Connect(const LocalTopologyProto& local_topology, @@ -42,7 +45,7 @@ class DistributedRuntimeClient { private: const std::unique_ptr stub_; - const absl::Duration rpc_timeout_ = absl::Seconds(120); + const absl::Duration rpc_timeout_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/distributed/protocol.h b/tensorflow/compiler/xla/pjrt/distributed/protocol.h index 4daa939ac8d..e8be43006f7 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/protocol.h +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.h @@ -18,7 +18,7 @@ limitations under the License. 
namespace xla { -static constexpr int kDistributedRuntimeProtocolVersion = 1; +static constexpr int kDistributedRuntimeProtocolVersion = 2; } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/distributed/protocol.proto b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto index 18bfa221110..c3bbb3a7f5d 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/protocol.proto +++ b/tensorflow/compiler/xla/pjrt/distributed/protocol.proto @@ -61,6 +61,7 @@ message ConnectRequest { int32 protocol_version = 1; // Always 1 at present. LocalTopologyProto local_topology = 2; + int32 timeout_milliseconds = 3; } message ConnectResponse { diff --git a/tensorflow/compiler/xla/pjrt/distributed/service.cc b/tensorflow/compiler/xla/pjrt/distributed/service.cc index 3325fcd8319..868529637de 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/service.cc +++ b/tensorflow/compiler/xla/pjrt/distributed/service.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/distributed/service.h" +#include "absl/time/time.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.h" #include "tensorflow/compiler/xla/pjrt/distributed/util.h" #include "tensorflow/compiler/xla/status.h" @@ -69,11 +70,12 @@ void BuildGlobalTopology(absl::Span local_topologies, mu_.AssertHeld(); return num_nodes_present_ == nodes_.size(); }; + auto connect_timeout = absl::Milliseconds(request->timeout_milliseconds()); if (!mu_.AwaitWithTimeout(absl::Condition(&all_nodes_present), - kConnectTimeout)) { + connect_timeout)) { return ToGrpcStatus(tensorflow::errors::DeadlineExceeded( "Timed out after %s waiting for all nodes to call Connect()", - absl::FormatDuration(kConnectTimeout))); + absl::FormatDuration(connect_timeout))); } if (node_id == 0) { diff --git a/tensorflow/compiler/xla/pjrt/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h index 9ecbdb3cc7c..fe323d9f3b2 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -50,8 +50,6 @@ class DistributedRuntimeServiceImpl final KeyValueSetResponse* response) override; private: - const absl::Duration kConnectTimeout = absl::Seconds(120); - absl::Mutex mu_; enum class State { kInitializing, kRunning }; State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 7e0d0159f4b..c5dce4a37f7 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -1004,7 +1004,7 @@ PjRtBuffer::GetBufferForHoldLocked(ScopedHold::Type type) { // acquiring any other kind of hold. WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return InvalidArgument("Hold requested on invalid buffer"); + return InvalidArgument("Hold requested on deleted or donated buffer"); } else { ++holds_[type]; } @@ -1084,7 +1084,8 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, // We can't perform any other action while a donation hold is in progress. 
WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return InvalidArgument("CopyToHostAsync() called on invalid buffer."); + return InvalidArgument( + "CopyToHostAsync() called on deleted or donated buffer"); } if (discard_cached_copy) { auto it = host_values_.find(host_layout); @@ -1154,7 +1155,7 @@ StatusOr> PjRtBuffer::ToLiteral( TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, CopyToHostAsyncInternal(discard_cached_copy, layout)); if (host_value == nullptr) { - return InvalidArgument("ToLiteral called on invalid buffer"); + return InvalidArgument("ToLiteral called on deleted or donated buffer"); } host_value->ready.WaitForNotification(); TF_RETURN_IF_ERROR(host_value->status); @@ -1272,7 +1273,8 @@ StatusOr> PjRtBuffer::CopyToDevice( // We can't perform any other action while a donation hold is in progress. WaitForOutstandingDonationHold(); if (device_buffer_ == nullptr) { - return InvalidArgument("CopyToDevice called on invalid buffer"); + return InvalidArgument( + "CopyToDevice called on deleted or donated buffer"); } AcquireHoldLocked(&src_device_buffer); } @@ -1313,7 +1315,8 @@ Status PjRtBuffer::BlockHostUntilReady() { { absl::MutexLock lock(&mu_); if (device_buffer_ == nullptr) { - return InvalidArgument("BlockHostUntilReady() called on invalid buffer."); + return InvalidArgument( + "BlockHostUntilReady() called on deleted or donated buffer"); } device_buffer = device_buffer_; } @@ -1383,7 +1386,7 @@ StatusOr MakeTupleHelper( local_device->compute_stream()->parent(), root_table_memory.cref())); } - ExecutionInput execution_input(on_device_shape); + ExecutionInput execution_input(on_device_shape, on_host_shape); ShapeTree::iterator input_iterator = execution_input.MutableBuffers()->begin(); ShapeTree::iterator iterator_end = @@ -1521,7 +1524,6 @@ StatusOr PjRtExecutable::EnqueueExecution( << " mapped to device ordinal for execution: " << device_ordinal; absl::flat_hash_set events; - std::vector argument_host_shapes; std::vector execution_inputs; device_buffers->reserve(argument_handles.size()); const absl::flat_hash_set& parameters_that_must_be_donated = @@ -1570,24 +1572,22 @@ StatusOr PjRtExecutable::EnqueueExecution( } LocalDeviceState* device_state = &client_->device_state(device_ordinal); - TupleHandle tuple_handle; + absl::optional tuple_handle; if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { TF_ASSIGN_OR_RETURN(tuple_handle, MakeTupleHelper(client_, device_state, argument_handles, *device_buffers, device_ordinal)); - events.insert(tuple_handle.event.get()); - execution_inputs.emplace_back(std::move(tuple_handle.execution_input)); - argument_host_shapes.push_back(&tuple_handle.on_host_shape); + events.insert(tuple_handle->event.get()); + execution_inputs.emplace_back(std::move(tuple_handle->execution_input)); } else { - argument_host_shapes.reserve(argument_handles.size()); execution_inputs.reserve(argument_handles.size()); for (int i = 0; i < argument_handles.size(); ++i) { PjRtBuffer* handle = argument_handles[i]; - argument_host_shapes.push_back(&handle->on_host_shape()); const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i]; // Make an ExecutionInput from the device buffer. 
- execution_inputs.emplace_back(handle->on_device_shape()); + execution_inputs.emplace_back(handle->on_device_shape(), + handle->on_host_shape()); ExecutionInput& execution_input = execution_inputs.back(); ShapeTree::iterator input_iterator = execution_input.MutableBuffers()->begin(); @@ -1613,6 +1613,10 @@ StatusOr PjRtExecutable::EnqueueExecution( run_options.set_run_id(run_id); run_options.set_rng_seed(device_state->GetNewPrngSeed()); run_options.set_gpu_executable_run_options(client_->gpu_run_options()); + run_options.set_launch_id(options.launch_id); + if (run_options.launch_id() != 0) { + VLOG(1) << "launch id for " << name() << ": " << run_options.launch_id(); + } // The choice of where we wait is arbitrary; the reason for the wait is // pacing to avoid problems such as memory fragmentation and running ahead @@ -1623,8 +1627,8 @@ StatusOr PjRtExecutable::EnqueueExecution( device_state->compute_semaphore().ScopedAcquire(1)); StatusOr result_buffer_or_status = - executables_[executable_idx]->RunAsync( - argument_host_shapes, std::move(execution_inputs), run_options); + executables_[executable_idx]->RunAsync(std::move(execution_inputs), + run_options); VLOG(1) << "Replica " << replica << " partition " << partition << " completed; ok=" << result_buffer_or_status.ok(); @@ -2141,13 +2145,13 @@ StatusOr, Shape>> GetShardedProgramShapes( client->client()->Compile(computation, argument_layout_pointers, build_options)); - auto py_executable = absl::make_unique( + auto executable = absl::make_unique( std::move(local_executables), options.parameter_is_tupled_arguments, std::move(device_assignment), std::move(local_logical_device_ids), std::move(local_devices), client); - TF_RETURN_IF_ERROR(py_executable->SetUpDonation( - client, options.parameter_is_tupled_arguments)); - return py_executable; + TF_RETURN_IF_ERROR( + executable->SetUpDonation(client, options.parameter_is_tupled_arguments)); + return executable; } } // namespace xla diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index b234027adf3..bb9093a8bf7 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -119,6 +119,8 @@ struct PjRtCrossHostRecvBuffer { using PjRtCrossHostRecvNotifier = std::function>&&)>; +class PjRtExecutable; + // Encapsulates the state of Python session with XLA. // // It is the responsibility of the client of this API to keep the PjRtClient @@ -181,6 +183,13 @@ class PjRtClient { virtual StatusOr> GetParametersThatMustBeDonated( const LocalExecutable& executable, bool tuple_inputs) const; + // Generates a unique fingerprint for `executable`. See + // PjRtExecutable::fingerprint_. + virtual StatusOr> ExecutableFingerprint( + const PjRtExecutable& executable) const { + return absl::optional(); + } + protected: friend class PjRtBuffer; virtual void EnqueueCrossHostReceive( @@ -668,6 +677,11 @@ struct ExecuteOptions { // If true, the computation must return a tuple, which will be destructured // into its elements. bool untuple_result = false; + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. 
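+ // Callers typically derive the launch ID from a fingerprint of the executable (PyExecutable, for example, applies tensorflow::Fingerprint32 to the executable fingerprint) so that every participant in the same launch supplies the same value.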
+ int32 launch_id = 0; }; // Represents a compiled computation that can be executed given handles to @@ -687,6 +701,8 @@ class PjRtExecutable { std::vector> local_logical_device_ids, std::vector local_devices, PjRtClient* client); + virtual ~PjRtExecutable() = default; + PjRtClient* client() const { return client_; } int num_replicas() const { @@ -744,12 +760,14 @@ class PjRtExecutable { // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. Status SetUpDonation(PjRtClient* client, bool tuple_inputs); + StatusOr EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, Device* device, std::vector* device_buffers, std::shared_ptr device_assignment) const; + StatusOr>> ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2143d1dfbe7..c932469c56a 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -112,6 +112,21 @@ xla::PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth) { } } +xla::PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth) { + switch (src_bitwidth) { + case 8: + return xla::S8; + case 16: + return xla::S16; + case 32: + return xla::S32; + case 64: + return xla::S64; + default: + return xla::PRIMITIVE_TYPE_INVALID; + } +} + PrimitiveType ComplexComponentType(PrimitiveType complex_type) { switch (complex_type) { case C64: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 034c14e8930..1228b4f9a32 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -153,6 +153,8 @@ int BitWidth(PrimitiveType type); PrimitiveType UnsignedIntegralTypeForBitWidth(int64 src_bitwidth); +PrimitiveType SignedIntegralTypeForBitWidth(int64 src_bitwidth); + // Returns the real, imag component type underlying the given complex type. // LOG(FATAL)'s if complex_type is not complex. 
PrimitiveType ComplexComponentType(PrimitiveType complex_type); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 94e345091eb..aa55a39218d 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -202,6 +202,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core/platform:fingerprint", "//tensorflow/core/profiler:protos_all_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -400,7 +401,7 @@ pybind_extension( "//tensorflow/core:lib_internal_impl", # buildcleaner: keep "//tensorflow/core/profiler/lib:profiler_backends", "//tensorflow/core/profiler/lib:profiler_session", - "//tensorflow/core/profiler/rpc:profiler_server", + "//tensorflow/core/profiler/rpc:profiler_server_impl", "//tensorflow/python/profiler/internal:traceme_wrapper", "//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:platform", diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index 6d80a60550b..1f21b3fb242 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -455,10 +455,10 @@ int NPyBfloat16_Compare(const void* a, const void* b, void* arr) { return 1; } // NaNs sort to the end. - if (!std::isnan(x) && std::isnan(y)) { + if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) { return -1; } - if (std::isnan(x) && !std::isnan(y)) { + if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) { return 1; } return 0; @@ -962,7 +962,7 @@ struct Frexp { struct Heaviside { bfloat16 operator()(bfloat16 bx, bfloat16 h0) { float x = static_cast(bx); - if (std::isnan(x)) { + if (Eigen::numext::isnan(x)) { return bx; } if (x < 0) { @@ -984,7 +984,9 @@ struct IsInf { bool operator()(bfloat16 a) { return std::isinf(static_cast(a)); } }; struct IsNan { - bool operator()(bfloat16 a) { return std::isnan(static_cast(a)); } + bool operator()(bfloat16 a) { + return Eigen::numext::isnan(static_cast(a)); + } }; struct Ldexp { bfloat16 operator()(bfloat16 a, int exp) { @@ -1200,25 +1202,25 @@ struct Ge { struct Maximum { bfloat16 operator()(bfloat16 a, bfloat16 b) { float fa(a), fb(b); - return std::isnan(fa) || fa > fb ? a : b; + return Eigen::numext::isnan(fa) || fa > fb ? a : b; } }; struct Minimum { bfloat16 operator()(bfloat16 a, bfloat16 b) { float fa(a), fb(b); - return std::isnan(fa) || fa < fb ? a : b; + return Eigen::numext::isnan(fa) || fa < fb ? a : b; } }; struct Fmax { bfloat16 operator()(bfloat16 a, bfloat16 b) { float fa(a), fb(b); - return std::isnan(fb) || fa > fb ? a : b; + return Eigen::numext::isnan(fb) || fa > fb ? a : b; } }; struct Fmin { bfloat16 operator()(bfloat16 a, bfloat16 b) { float fa(a), fb(b); - return std::isnan(fb) || fa < fb ? a : b; + return Eigen::numext::isnan(fb) || fa < fb ? 
a : b; } }; @@ -1244,7 +1246,8 @@ struct NextAfter { float from_as_float(from), to_as_float(to); memcpy(&from_as_int, &from, sizeof(bfloat16)); memcpy(&to_as_int, &to, sizeof(bfloat16)); - if (std::isnan(from_as_float) || std::isnan(to_as_float)) { + if (Eigen::numext::isnan(from_as_float) || + Eigen::numext::isnan(to_as_float)) { return bfloat16(std::numeric_limits::quiet_NaN()); } if (from_as_int == to_as_int) { diff --git a/tensorflow/compiler/xla/python/ops.cc b/tensorflow/compiler/xla/python/ops.cc index 9362a367dfc..3ac4709b160 100644 --- a/tensorflow/compiler/xla/python/ops.cc +++ b/tensorflow/compiler/xla/python/ops.cc @@ -114,24 +114,26 @@ void BuildOpsSubmodule(py::module* m) { "CustomCall", [](XlaBuilder* builder, const py::bytes& call_target_name, absl::Span operands, const Shape& shape, - const py::bytes& opaque) -> XlaOp { - return CustomCall(builder, call_target_name, operands, shape, opaque); + const py::bytes& opaque, bool has_side_effect) -> XlaOp { + return CustomCall(builder, call_target_name, operands, shape, opaque, + has_side_effect); }, py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), - py::arg("shape"), py::arg("opaque") = py::bytes("")); + py::arg("shape"), py::arg("opaque") = py::bytes(""), + py::arg("has_side_effect") = false); ops.def( "CustomCallWithLayout", [](XlaBuilder* builder, const py::bytes& call_target_name, absl::Span operands, const Shape& shape_with_layout, absl::Span operand_shapes_with_layout, - const py::bytes& opaque) -> XlaOp { - return CustomCallWithLayout(builder, call_target_name, operands, - shape_with_layout, - operand_shapes_with_layout, opaque); + const py::bytes& opaque, bool has_side_effect) -> XlaOp { + return CustomCallWithLayout( + builder, call_target_name, operands, shape_with_layout, + operand_shapes_with_layout, opaque, has_side_effect); }, py::arg("builder"), py::arg("call_target_name"), py::arg("operands"), py::arg("shape_with_layout"), py::arg("operand_shapes_with_layout"), - py::arg("opaque") = py::bytes("")); + py::arg("opaque") = py::bytes(""), py::arg("has_side_effect") = false); ops.def("Dot", &Dot, py::arg("lhs"), py::arg("rhs"), py::arg("precision_config") = nullptr); ops.def("DotGeneral", &DotGeneral, py::arg("lhs"), py::arg("rhs"), diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index bc7244cfc64..9b95f8e03de 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/py_client.h" +#include + #include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/python/py_buffer.h" #include "tensorflow/compiler/xla/python/py_executable.h" @@ -83,7 +85,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { return result; } -StatusOr> PyClient::BufferFromPyal( +StatusOr> PyClient::BufferFromPyval( const pybind11::object& argument, Device* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics) { if (device == nullptr) { @@ -104,7 +106,6 @@ StatusOr> PyClient::BufferFromPyal( return InvalidArgument("from_python argument must be an array."); } - TF_ASSIGN_OR_RETURN(PythonBufferTree tree, GetPythonBufferTree(argument)); std::shared_ptr py_buffer_ref = GlobalPyRefManager()->ManageReference(std::move(c->array)); @@ -121,18 +122,22 @@ StatusOr> PyClient::BufferFromPyal( std::move(traceback)); } -StatusOr> PyClient::Compile( +StatusOr> PyClient::Compile( const XlaComputation& computation, CompileOptions options) { std::unique_ptr executable; + absl::optional fingerprint; { py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(executable, PjRtExecutable::Compile(computation, pjrt_client_.get(), std::move(options))); + TF_ASSIGN_OR_RETURN(fingerprint, + pjrt_client_->ExecutableFingerprint(*executable)); } auto traceback = Traceback::Get(); - return std::make_unique( - shared_from_this(), std::move(executable), std::move(traceback)); + return std::make_shared( + shared_from_this(), std::move(executable), std::move(traceback), + std::move(fingerprint)); } class ProfileBuilder { @@ -275,7 +280,8 @@ py::bytes PyClient::HeapProfile() { kind_label->set_str(buffer_string_id); auto* device_label = sample->add_label(); device_label->set_key(device_string_id); - device_label->set_num(entry.first.device->id()); + device_label->set_str( + builder.StringId(entry.first.device->DebugString())); } else { kind_label->set_str(executable_string_id); } diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index be61bd74419..e41415c42f2 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -120,11 +120,11 @@ class PyClient : public std::enable_shared_from_this { return pjrt_client_->client()->CreateHostToDeviceChannelHandle(); } - StatusOr> BufferFromPyal( + StatusOr> BufferFromPyval( const pybind11::object& argument, Device* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics); - StatusOr> Compile( + StatusOr> Compile( const XlaComputation& computation, CompileOptions options); pybind11::bytes HeapProfile(); diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index c56fd3a89fc..ed524f1cb33 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/python/py_executable.h" #include "absl/algorithm/container.h" +#include "tensorflow/core/platform/fingerprint.h" namespace xla { @@ -23,10 +24,12 @@ namespace py = pybind11; PyExecutable::PyExecutable(std::shared_ptr client, std::unique_ptr executable, - std::shared_ptr traceback) + std::shared_ptr traceback, + absl::optional fingerprint) : client_(std::move(client)), executable_(std::move(executable)), - traceback_(std::move(traceback)) { + traceback_(std::move(traceback)), + fingerprint_(std::move(fingerprint)) { CHECK(PyGILState_Check()); next_ = client_->executables_; client_->executables_ = this; @@ -34,6 +37,12 @@ PyExecutable::PyExecutable(std::shared_ptr client, if (next_) { next_->prev_ = this; } + options_.untuple_result = true; + if (fingerprint_) { + options_.launch_id = tensorflow::Fingerprint32(*fingerprint_); + VLOG(1) << "Fingerprint for executable " << executable_->name() << ": " + << *fingerprint_; + } } PyExecutable::~PyExecutable() { @@ -58,18 +67,33 @@ std::vector> PyExecutable::LocalDevices() const { return devices; } +StatusOr>> PyExecutable::PjRtExecute( + absl::Span args) { + std::vector> output_buffers; + { + py::gil_scoped_release gil_release; + TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute(args, options_)); + } + auto traceback = Traceback::Get(); + std::vector> outputs; + outputs.reserve(output_buffers.size()); + for (auto& buffer : output_buffers) { + outputs.push_back( + std::make_unique(client_, std::move(buffer), traceback)); + } + return outputs; +} + StatusOr>> PyExecutable::Execute( absl::Span args) { std::vector> output_buffers; { py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; std::vector arg_buffers(args.size()); absl::c_transform(args, arg_buffers.begin(), [](PyBuffer* buf) { return buf->buffer(); }); TF_ASSIGN_OR_RETURN(output_buffers, - executable_->Execute(arg_buffers, options)); + executable_->Execute(arg_buffers, options_)); } auto traceback = Traceback::Get(); std::vector> outputs; @@ -87,8 +111,6 @@ PyExecutable::ExecuteOnLocalDevices( std::vector>> output_buffers; { py::gil_scoped_release gil_release; - ExecuteOptions options; - options.untuple_result = true; std::vector> arg_buffers(args.size()); for (int computation = 0; computation < args.size(); ++computation) { arg_buffers[computation].resize(args[computation].size()); @@ -96,7 +118,7 @@ PyExecutable::ExecuteOnLocalDevices( [](PyBuffer* buf) { return buf->buffer(); }); } TF_ASSIGN_OR_RETURN(output_buffers, executable_->ExecuteOnLocalDevices( - arg_buffers, options)); + arg_buffers, options_)); } auto traceback = Traceback::Get(); std::vector>> outputs; diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h index 7f35f97f6e9..24f177261e7 100644 --- a/tensorflow/compiler/xla/python/py_executable.h +++ b/tensorflow/compiler/xla/python/py_executable.h @@ -37,7 +37,8 @@ class PyExecutable { public: PyExecutable(std::shared_ptr client, std::unique_ptr executable, - std::shared_ptr traceback); + std::shared_ptr traceback, + absl::optional fingerprint); ~PyExecutable(); std::shared_ptr client() const { return client_; } @@ -57,6 +58,10 @@ class PyExecutable { StatusOr>> Execute( absl::Span args); + // Same as above, but take as inputs `PjRtBuffer*`. Only targets C++ code. 
+ StatusOr>> PjRtExecute( + absl::Span args); + StatusOr>>> ExecuteOnLocalDevices(absl::Span> args); @@ -64,6 +69,8 @@ class PyExecutable { Traceback* traceback() { return traceback_.get(); } + const PjRtExecutable& pjrt_executable() const { return *executable_; } + private: friend class PyClient; @@ -71,6 +78,14 @@ class PyExecutable { std::unique_ptr executable_; std::shared_ptr traceback_; + // Identical executables (i.e. representing the same program) will have the + // same fingerprint. nullopt on platforms or executables where fingerprints + // aren't implemented. + absl::optional fingerprint_; + + // The options to pass to `executable_.Execute`. + ExecuteOptions options_; + // Doubly-linked list of all executables known to the client. Protected by the // GIL. PyExecutable* next_; diff --git a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc index 7632f21d5b2..c6aff604aee 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/grpc_tpu_driver.cc @@ -657,7 +657,7 @@ void GrpcTpuStream::StreamWriterFn() { request_lock_.Unlock(); for (const auto& r : reqs) { - TraceMe activity(absl::StrCat("GrpcTpuStream::Send ")); + TraceMe activity("GrpcTpuStream::Send "); ::grpc::WriteOptions opts; opts.set_no_compression().clear_buffer_hint(); stream_->Write(r, opts); @@ -721,7 +721,7 @@ std::unique_ptr GrpcTpuStream::Allocate( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::Allocate(num_bytes)")); + TraceMe activity("GrpcTpuStream::Allocate(num_bytes)"); req->mutable_alloc()->set_core_id(core_id); req->mutable_alloc()->set_region(region); req->mutable_alloc()->set_num_bytes(num_bytes); @@ -737,7 +737,7 @@ std::unique_ptr GrpcTpuStream::Allocate( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::Allocate(shape)")); + TraceMe activity("GrpcTpuStream::Allocate(shape)"); req->mutable_alloc()->set_core_id(core_id); req->mutable_alloc()->set_region(region); *req->mutable_alloc()->mutable_shape() = shape; @@ -754,7 +754,7 @@ std::unique_ptr GrpcTpuStream::AllocateTuple( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::AllocateTuple")); + TraceMe activity("GrpcTpuStream::AllocateTuple"); req->mutable_alloc_tuple()->set_core_id(core_id); req->mutable_alloc_tuple()->set_region(region); for (auto child : children) { @@ -771,7 +771,7 @@ std::shared_ptr GrpcTpuStream::Deallocate( std::unique_ptr handle, absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::Deallocate")); + TraceMe activity("GrpcTpuStream::Deallocate"); auto grpc_handle = static_cast(handle.get()); req->mutable_dealloc()->set_handle(grpc_handle->id().AsInt()); auto event = @@ -784,7 +784,7 @@ std::shared_ptr GrpcTpuStream::TransferToDevice( const void* src, BufferHandle* dst, absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::TransferToDevice")); + TraceMe activity("GrpcTpuStream::TransferToDevice"); req->mutable_transfer_to()->mutable_data()->assign( static_cast(src), dst->size_in_bytes()); req->mutable_transfer_to()->set_target_handle( @@ -799,7 
+799,7 @@ std::shared_ptr GrpcTpuStream::TransferFromDevice( const BufferHandle* src, void* dst, absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDevice")); + TraceMe activity("GrpcTpuStream::TransferFromDevice"); req->mutable_transfer_from()->set_source_handle( static_cast(src)->id().AsInt()); EventId event_id = EventId::FromInt(req->operation_id()); @@ -818,8 +818,10 @@ std::shared_ptr GrpcTpuStream::TransferFromDeviceToDevice( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice", - req->operation_id())); + TraceMe activity([&req] { + return absl::StrCat("GrpcTpuStream::TransferFromDeviceToDevice", + req->operation_id()); + }); req->mutable_transfer_from_to()->set_source_handle( static_cast(src)->id().AsInt()); @@ -836,7 +838,7 @@ std::unique_ptr GrpcTpuStream::CompileProgram( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::CompileProgram")); + TraceMe activity("GrpcTpuStream::CompileProgram"); *req->mutable_compile()->mutable_hlo_program() = source; req->mutable_compile()->set_num_replicas(num_replicas); EventId event_id = EventId::FromInt(req->operation_id()); @@ -861,7 +863,7 @@ std::unique_ptr GrpcTpuStream::LoadProgram( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::LoadProgram")); + TraceMe activity("GrpcTpuStream::LoadProgram"); req->mutable_load()->set_core_id(core_id); auto grpc_handle = static_cast(handle); if (grpc_handle->id().client_id != driver_->client_id()) { @@ -884,7 +886,7 @@ std::shared_ptr GrpcTpuStream::UnloadProgram( absl::Span wait_for) { auto req = absl::make_unique(); InitializeRequest(req.get(), wait_for); - TraceMe activity(absl::StrCat("GrpcTpuStream::UnloadProgram")); + TraceMe activity("GrpcTpuStream::UnloadProgram"); req->mutable_unload()->set_loaded_program_handle( static_cast(handle.get())->id().AsInt()); auto event = diff --git a/tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h b/tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h index 285d59e2304..0c7cc370e2a 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h +++ b/tensorflow/compiler/xla/python/tpu_driver/platform/external/compat.h @@ -35,7 +35,13 @@ class Thread { class TraceMe { public: - explicit TraceMe(absl::string_view tag, int level = 1) {} + explicit TraceMe(absl::string_view name, int level = 1) {} + explicit TraceMe(std::string&& name, int level = 1) = delete; + explicit TraceMe(const std::string& name, int level = 1) = delete; + explicit TraceMe(const char* raw, int level = 1) + : TraceMe(absl::string_view(raw), level) {} + template + explicit TraceMe(NameGeneratorT name_generator, int level = 1) {} ~TraceMe() {} }; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index ed9b80775d8..510175cebf6 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -539,7 +539,7 @@ PYBIND11_MODULE(xla_extension, m) { &PyClient::CreateDeviceToHostChannelHandle) .def("create_host_to_device_channel_handle", &PyClient::CreateHostToDeviceChannelHandle) - .def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"), + .def("buffer_from_pyval", 
&PyClient::BufferFromPyval, py::arg("argument"), py::arg("device") = nullptr, py::arg("force_copy") = false, py::arg("host_buffer_semantics") = PjRtBuffer::HostBufferSemantics::kZeroCopy) @@ -654,7 +654,7 @@ PYBIND11_MODULE(xla_extension, m) { PyTypeObject* buffer_type = reinterpret_cast(buffer.ptr()); buffer_type->tp_as_buffer = PyBuffer::BufferProtocol(); - py::class_> executable( + py::class_> executable( m, "Executable"); executable.def_property_readonly("client", &PyExecutable::client) .def("local_logical_device_ids", &PyExecutable::local_logical_device_ids) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ed786992e4f..f5618b95c3e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -296,6 +296,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//tensorflow/core:lib", + "//third_party/eigen3", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", @@ -473,6 +474,7 @@ cc_library( "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:optional", @@ -1698,11 +1700,15 @@ cc_library( cc_library( name = "hlo_creation_utils", srcs = ["hlo_creation_utils.cc"], - hdrs = ["hlo_creation_utils.h"], + hdrs = [ + "hlo_creation_utils.h", + "//tensorflow/compiler/xla:literal_util", + ], deps = [ ":hlo", ":hlo_module_config", ":shape_inference", + "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -1813,6 +1819,21 @@ cc_library( ], ) +cc_library( + name = "comparison_expander", + srcs = ["comparison_expander.cc"], + hdrs = ["comparison_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":op_expander_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client/lib:comparators", + ], +) + cc_library( name = "scatter_expander", srcs = ["scatter_expander.cc"], @@ -1871,6 +1892,27 @@ cc_library( ], ) +tf_cc_test( + name = "triangular_solve_expander_test", + size = "medium", + srcs = ["triangular_solve_expander_test.cc"], + shard_count = 3, + deps = [ + ":hlo", + ":triangular_solve_expander", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:verified_hlo_module", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + ], +) + cc_library( name = "cholesky_expander", srcs = ["cholesky_expander.cc"], @@ -2235,6 +2277,7 @@ tf_cc_test( srcs = ["gather_expander_test.cc"], deps = [ ":gather_expander", + ":hlo_query", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_macros_header", @@ -3383,6 +3426,15 @@ cc_library( ], ) +cc_library( + name = "memory_space_assignment_repacking", + hdrs = ["memory_space_assignment_repacking.h"], + deps = [ + "//tensorflow/compiler/xla:statusor", + 
"//tensorflow/compiler/xla:types", + ], +) + cc_library( name = "memory_space_assignment", srcs = ["memory_space_assignment.cc"], @@ -3390,6 +3442,7 @@ cc_library( deps = [ ":heap_simulator", ":hlo_cost_analysis", + ":memory_space_assignment_repacking", ":memory_space_assignment_utils", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/core/lib/math:math_util", @@ -3911,6 +3964,39 @@ tf_cc_test( ], ) +cc_library( + name = "conditional_canonicalizer", + srcs = ["conditional_canonicalizer.cc"], + hdrs = ["conditional_canonicalizer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + ], +) + +tf_cc_test( + name = "conditional_canonicalizer_test", + srcs = ["conditional_canonicalizer_test.cc"], + deps = [ + ":conditional_canonicalizer", + ":hlo", + ":hlo_matchers", + ":hlo_parser", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_get_dimension_size_rewriter", srcs = ["hlo_get_dimension_size_rewriter.cc"], @@ -4928,3 +5014,34 @@ cc_library( "//tensorflow/stream_executor/lib", ], ) + +cc_library( + name = "topk_rewriter", + srcs = ["topk_rewriter.cc"], + hdrs = ["topk_rewriter.h"], + deps = [ + ":hlo", + ":hlo_casting_utils", + ":hlo_pass", + ":pattern_matcher", + "//tensorflow/compiler/xla:shape_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:optional", + ], +) + +tf_cc_test( + name = "topk_rewriter_test", + srcs = ["topk_rewriter_test.cc"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_matchers", + ":topk_rewriter", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:test_macros_cpu", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 130661bf1cd..fa4d0e47a5d 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -428,6 +428,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { shape, hlo, zero, dims, AddReduce_computation)); } + // Move scalar multiply to the smallest side of convolution to + // reduce multiply computations. + Status ScalarMultiplyReduction(HloInstruction* dot); + // Convenience method for replacing an instruction with a bitcast. If operand // is not null, then the bitcast will use the specified operand instead of the // operand of the instruction. @@ -509,6 +513,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to convert slice(reverse(X)) into reverse(slice(X)) + StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into // `(< a N)`. This is crucial for being able to figure out the loop trip // count. 
@@ -560,6 +567,200 @@ bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs, } } +namespace { + +float GetConstantValue(HloInstruction* inst) { + switch (inst->shape().element_type()) { + case BF16: + return static_cast(inst->literal().GetFirstElement()); + case F32: + return inst->literal().GetFirstElement(); + default: + LOG(FATAL) << "Unsupported data type: " << inst->shape().element_type(); + } +} + +bool IsOpCodeMultiplyCommutative(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kMultiply: + case HloOpcode::kTranspose: + case HloOpcode::kReshape: + case HloOpcode::kSelect: + return true; + default: + return false; + } +} + +std::unique_ptr MakeScalarInstruction(HloInstruction* target, + float multiplier) { + switch (target->shape().element_type()) { + case BF16: + return HloInstruction::CreateConstant(LiteralUtil::ConvertF32ToBF16( + LiteralUtil::CreateR0(multiplier))); + break; + case F32: + return HloInstruction::CreateConstant( + LiteralUtil::CreateR0(multiplier)); + break; + default: + LOG(FATAL) << "Unsupported data type: " << target->shape().element_type(); + } +} + +} // namespace + +Status AlgebraicSimplifierVisitor::ScalarMultiplyReduction( + HloInstruction* dot) { + // We only process bfloat16 and float32 for now. + if (dot->shape().element_type() != BF16 && + dot->shape().element_type() != F32) { + return Status::OK(); + } + + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + const int64 dot_size = ShapeUtil::ElementsIn(dot->shape()); + const int64 lhs_size = ShapeUtil::ElementsIn(lhs->shape()); + const int64 rhs_size = ShapeUtil::ElementsIn(rhs->shape()); + + HloInstruction* target = nullptr; + // (current node, user, operand_index) + std::vector> operands; + std::vector users; + + // Find which side of dot has the smallest size: + // operand 0, operand 1, or output. + if (dot_size <= std::min(lhs_size, rhs_size)) { + target = dot; + if (dot_size < lhs_size) { + operands.emplace_back(lhs, dot, 0); + } + if (dot_size < rhs_size) { + operands.emplace_back(rhs, dot, 1); + } + } else if (lhs_size <= rhs_size) { + target = lhs; + if (lhs_size < rhs_size) { + operands.emplace_back(rhs, dot, 1); + } + if (lhs_size < dot_size && dot->user_count() == 1) { + users.push_back(dot->users().front()); + } + } else { + target = rhs; + if (rhs_size < lhs_size) { + operands.emplace_back(lhs, dot, 0); + } + if (rhs_size < dot_size && dot->user_count() == 1) { + users.push_back(dot->users().front()); + } + } + + std::vector values; + + // DFS to find scalar multiply ops from the operands. + while (!operands.empty()) { + HloInstruction* inst; + HloInstruction* user; + int64 index; + std::tie(inst, user, index) = operands.back(); + operands.pop_back(); + + // Skip the op types that are not commutative with multiply. + if (!IsOpCodeMultiplyCommutative(inst->opcode())) { + continue; + } + + HloInstruction* operand; + HloInstruction* multiplier; + // Pattern match a scalar multiply. + if (Match(inst, m::MultiplyAnyOrder( + m::Op(&operand), + m::Broadcast(m::ConstantScalar(&multiplier))))) { + CHECK_LT(index, user->operand_count()); + CHECK_EQ(inst, user->operands()[index]); + + // When found a scalar multiply, save its scalar value. + values.push_back(GetConstantValue(multiplier)); + // And remove the scalar multiply op. + TF_RETURN_IF_ERROR(user->ReplaceOperandWith(index, operand)); + inst = operand; + } + + // Push the operands of inst. 
+ int64 i = 0; + for (auto* operand : inst->operands()) { + operands.emplace_back(operand, inst, i++); + } + } + + // DFS to find scalar multiply ops from the users. + while (!users.empty()) { + auto inst = users.back(); + users.pop_back(); + + if (!IsOpCodeMultiplyCommutative(inst->opcode())) { + continue; + } + + HloInstruction* operand; + HloInstruction* multiplier; + if (Match(inst, m::MultiplyAnyOrder( + m::Op(&operand), + m::Broadcast(m::ConstantScalar(&multiplier))))) { + values.push_back(GetConstantValue(multiplier)); + + TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(operand)); + inst = operand; + } + + // Process the instructions with only one user. + // Otherwise moving scalar multiply to the operands changes the values of + // other users. + if (inst->user_count() == 1) { + users.push_back(inst->users().front()); + } + } + + if (values.empty()) { + return Status::OK(); + } + + changed_ = true; + + // Combine all constant multipliers. + float multiplier = 1.0; + for (const float v : values) { + multiplier *= v; + } + + // Create a new const scalar multiply instruction. + HloInstruction* new_const_inst; + new_const_inst = + computation_->AddInstruction(MakeScalarInstruction(target, multiplier)); + + // Broadcast the scalar multiplier. + HloInstruction* new_broadcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(target->shape(), new_const_inst, {})); + // Create a new scalar multiply instruction. + HloInstruction* new_multiply = + computation_->AddInstruction(HloInstruction::CreateBinary( + target->shape(), HloOpcode::kMultiply, target, new_broadcast)); + CHECK_EQ(new_multiply->shape(), target->shape()); + + // Update the dependency with the rest of the instructions. + if (target == lhs) { + return dot->ReplaceOperandWith(0, new_multiply); + } else if (target == rhs) { + return dot->ReplaceOperandWith(1, new_multiply); + } else { + CHECK_EQ(target, dot); + return dot->ReplaceAllUsesWith(new_multiply); + } +} + void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction, HloInstruction* operand) { CHECK_EQ(1, instruction->operand_count()); @@ -1035,6 +1236,10 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( return Status::OK(); } + if (options_.is_layout_sensitive()) { + return Status::OK(); + } + // Check if we can merge "adjacent" slice operands which take slices from the // same other op. For simplicity we only merge unstrided slices. 
int64 concatenate_dimension = concatenate->concatenate_dimension(); @@ -1134,6 +1339,23 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate( operands[pad_value_operand]->mutable_operand(0), padding_config)); return ReplaceInstruction(concatenate, pad); } + + if (absl::c_count(operands, operands[0]) == operands.size() && + operands[0]->shape().dimensions(concatenate_dimension) == 1) { + Shape new_shape = operands[0]->shape(); + absl::InlinedVector broadcast_dims; + for (int64 i = 0; i < new_shape.rank(); ++i) { + if (i == concatenate_dimension) { + continue; + } + broadcast_dims.push_back(i); + } + new_shape.DeleteDimension(concatenate_dimension); + return ReplaceInstruction( + concatenate, + MakeBroadcastHlo(MakeReshapeHlo(new_shape, operands[0]).ValueOrDie(), + broadcast_dims, concatenate->shape())); + } return Status::OK(); } @@ -2098,11 +2320,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), AsInt64Slice( dot->dot_dimension_numbers().lhs_contracting_dimensions()))); - if (dot->shape().rank() != lhs->shape().rank()) { - std::vector lhs_broadcast_dims(lhs->shape().rank()); - absl::c_iota(lhs_broadcast_dims, 0); - new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( - dot->shape(), new_lhs, lhs_broadcast_dims)); + if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { + new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); } TF_ASSIGN_OR_RETURN( HloInstruction * new_rhs, @@ -2111,6 +2330,15 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), AsInt64Slice( dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { + new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); + } + if (dot->shape().rank() != lhs->shape().rank()) { + std::vector lhs_broadcast_dims(lhs->shape().rank()); + absl::c_iota(lhs_broadcast_dims, 0); + new_lhs = computation_->AddInstruction(HloInstruction::CreateBroadcast( + dot->shape(), new_lhs, lhs_broadcast_dims)); + } if (dot->shape().rank() != rhs->shape().rank()) { std::vector rhs_broadcast_dims( dot->dot_dimension_numbers().lhs_batch_dimensions_size()); @@ -2129,8 +2357,6 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // If the lhs or rhs have only batch and contracting dimensions, a dot can be // rewritten as reduce(mul(broadcast(transpose(x)),broadcast(transpose(y)))) if (options_.enable_dot_strength_reduction() && - (ShapeUtil::ElementIsFloating(dot->shape()) || - ShapeUtil::ElementIsComplex(dot->shape())) && ((dot->dot_dimension_numbers().lhs_batch_dimensions_size() + dot->dot_dimension_numbers().lhs_contracting_dimensions_size() == lhs->shape().rank()) || @@ -2144,6 +2370,10 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { AsInt64Slice(dot->dot_dimension_numbers().lhs_batch_dimensions()), AsInt64Slice( dot->dot_dimension_numbers().lhs_contracting_dimensions()))); + if (!ShapeUtil::SameElementType(dot->shape(), new_lhs->shape())) { + new_lhs = MakeConvertToHlo(new_lhs, dot->shape().element_type()); + } + TF_ASSIGN_OR_RETURN( HloInstruction * new_rhs, NormalizeDotOperandToBatchMajorAndContractingMinor( @@ -2151,6 +2381,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { AsInt64Slice(dot->dot_dimension_numbers().rhs_batch_dimensions()), AsInt64Slice( 
dot->dot_dimension_numbers().rhs_contracting_dimensions()))); + if (!ShapeUtil::SameElementType(dot->shape(), new_rhs->shape())) { + new_rhs = MakeConvertToHlo(new_rhs, dot->shape().element_type()); + } int64 lhs_outer_dims = lhs->shape().rank() - @@ -2192,9 +2425,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { std::vector reduce_dims( dot->dot_dimension_numbers().lhs_contracting_dimensions_size()); PrimitiveType dot_type = - ShapeUtil::ElementIsComplex(dot->shape()) - ? dot->shape().element_type() - : dot->shape().element_type() == F64 ? F64 : F32; + ShapeUtil::ElementIsFloating(dot->shape()) + ? (dot->shape().element_type() == F64 ? F64 : F32) + : dot->shape().element_type(); new_dot = AsType(new_dot, dot_type); const int64 outer_dims = std::max(rhs_outer_dims, lhs_outer_dims); absl::c_iota( @@ -2474,6 +2707,70 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } } + { + HloInstruction *a, *b, *c1, *c2; + // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y), + // constant1*constant2) + if (Match(multiply, + m::Multiply( + m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), + m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) { + TF_ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + if (ShapeUtil::IsScalar(product_of_constants->shape()) && + !ShapeUtil::IsScalar(multiply->shape())) { + product_of_constants = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), product_of_constants, {})); + } + + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary( + multiply->shape(), HloOpcode::kMultiply, + computation_->AddInstruction(HloInstruction::CreateBinary( + multiply->shape(), HloOpcode::kMultiply, a, b)), + product_of_constants)); + } + } + + { + HloInstruction *a, *b, *constant, *op; + // Mul(Mul(a, constant1), Broadcast(b)) => + // Mul(Broadcast(Mul(b, constant1), a)) + if (Match(multiply, + m::MultiplyAnyOrder(m::MultiplyAnyOrder(m::NonConstant(&a), + m::Constant(&constant)), + m::Op(&op))) || + Match(multiply, + m::MultiplyAnyOrder( + m::MultiplyAnyOrder(m::NonConstant(&a), + m::Broadcast(m::Constant(&constant))), + m::Op(&op)))) { + // Check that the other side was a broadcast, and not of a constant. + if (ShapeUtil::IsScalar(constant->shape()) && + Match(op, m::Broadcast(m::NonConstant()))) { + auto dims = op->dimensions(); + b = op->mutable_operand(0); + if (!ShapeUtil::IsScalar(b->shape())) { + constant = computation_->AddInstruction( + HloInstruction::CreateBroadcast(b->shape(), constant, {})); + } + + auto new_mul = + computation_->AddInstruction(HloInstruction::CreateBinary( + b->shape(), HloOpcode::kMultiply, b, constant)); + + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary( + multiply->shape(), HloOpcode::kMultiply, a, + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), new_mul, dims)))); + } + } + } + VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]"; HloInstruction *a, *c1, *c2; if (Match(multiply, @@ -3087,6 +3384,17 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { HloOpcode::kMultiply, lhs, lhs)); } + // Pow(A, 3) is used in GELU. 
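+ // (The tanh-based GELU approximation evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))), which is where the cube appears.)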
+ VLOG(10) << "trying transform [pow(A, 3) => A*A*A]: " << power->ToString(); + if (IsAll(rhs, 3)) { + HloInstruction* tmp = + computation_->AddInstruction(HloInstruction::CreateBinary( + power->shape(), HloOpcode::kMultiply, lhs, lhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), + HloOpcode::kMultiply, lhs, tmp)); + } + VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { return ReplaceWithNewInstruction( @@ -3576,6 +3884,52 @@ StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReshape( return false; } +// Allowing a slice to move through a reverse with any necessary updates to the +// slice config. +StatusOr AlgebraicSimplifierVisitor::TryToReorderSliceAndReverse( + HloInstruction* slice) { + VLOG(2) << "Entered TryToReorderSliceAndReverse for slice:" + << slice->ToString(); + if (Match(slice, m::Slice(m::Reverse()))) { + HloInstruction* reverse = slice->mutable_operand(0); + HloInstruction* reverse_operand = reverse->mutable_operand(0); + std::vector new_starts = slice->slice_starts(); + std::vector new_limits = slice->slice_limits(); + std::vector new_strides = slice->slice_strides(); + for (auto rdim : reverse->dimensions()) { + int64 start = slice->slice_starts(rdim); + int64 limit = slice->slice_limits(rdim); + int64 stride = slice->slice_strides(rdim); + // find_nth allows us to compute the appropriate index to begin + // with during reverse even in the presence of non-unit strides + int64 find_nth = (limit - start - 1) / stride; + find_nth = start + find_nth * stride; + limit = find_nth + 1; + new_starts[rdim] = + (reverse->shape().dimensions(rdim) - start) - (limit - start); + new_limits[rdim] = reverse->shape().dimensions(rdim) - start; + VLOG(2) << "Analyzing dim:" << rdim << " (start,limit):" << start << "," + << limit << " and new (start, limit):" << new_starts[rdim] << "," + << new_limits[rdim]; + } + // New slice formed from the reverse_operand, but strides and shape of the + // slice output remains the same. New slice's starts and limits are updated + // for ONLY the reversed dimensions as indicated above. + HloInstruction* new_slice = computation_->AddInstruction( + HloInstruction::CreateSlice(slice->shape(), reverse_operand, new_starts, + new_limits, new_strides)); + simplifier_->UpdateLayout(new_slice->mutable_shape()); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + slice, HloInstruction::CreateReverse(new_slice->shape(), new_slice, + reverse->dimensions()))); + // We do not delete the old reverse, since there might be another + // consumer of that reverse (i.e., full reverse output). DCE should take + // care of any deletion that is necessary if there was no use of reverse. + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Delete no-op slices, i.e. where shape = operand shape. 
if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { @@ -3730,6 +4084,15 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { if (replaced) { return Status::OK(); } + + bool reversed = false; + if (Match(slice, m::Slice(m::Reverse(m::Op())))) { + TF_ASSIGN_OR_RETURN(reversed, TryToReorderSliceAndReverse(slice)); + } + if (reversed) { + return Status::OK(); + } + return Status::OK(); } @@ -3798,8 +4161,8 @@ Status AlgebraicSimplifierVisitor::HandleDynamicSlice( return ReplaceWithNewInstruction(dynamic_slice, std::move(new_broadcast)); } - // Convert a dynamic slice into a slice if all offsets are constant and the - // operand is not constant. If ev + // Convert a dynamic slice into a slice if all offsets are constant and the + // operand is not constant. if (operand->opcode() != HloOpcode::kConstant && absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, dynamic_slice->operands().end()), @@ -4137,13 +4500,13 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { new_dnums.add_rhs_contracting_dimensions( dnums.rhs_batch_dimensions(batch_dim)); new_dnums.add_lhs_contracting_dimensions( - dnums.rhs_batch_dimensions(batch_dim)); + dnums.lhs_batch_dimensions(batch_dim)); ++removed_dims; } else { new_dnums.add_rhs_batch_dimensions( dnums.rhs_batch_dimensions(batch_dim)); new_dnums.add_lhs_batch_dimensions( - dnums.rhs_batch_dimensions(batch_dim)); + dnums.lhs_batch_dimensions(batch_dim)); } } std::vector reduce_dims; @@ -4697,15 +5060,17 @@ StatusOr AlgebraicSimplifierVisitor::SwapConvOperands( for (int64 spatial_dim = 0; spatial_dim < dnums.input_spatial_dimensions_size(); ++spatial_dim) { const int64 kernel_size = window_dims[spatial_dim].size(); - kernel_product *= kernel_size; const int64 dilated_kernel_size = 1 + (kernel_size - 1) * window_dims[spatial_dim].window_dilation(); const int64 input_size = input->shape().dimensions(dnums.input_spatial_dimensions(spatial_dim)); - swapped_kernel_product *= input_size; const int64 dilated_input_size = 1 + (input_size - 1) * window_dims[spatial_dim].base_dilation(); + // Don't decide to swap if the input size is one, since many convolution + // implementations can easily handle that special case efficiently. + kernel_product *= kernel_size; + swapped_kernel_product *= input_size == 1 ? kernel_size : input_size; auto new_dim = swapped_window.add_dimensions(); new_dim->set_size(input_size); @@ -4896,6 +5261,10 @@ StatusOr AlgebraicSimplifierVisitor::SimplifyConvToDot( Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { + if (options_.enable_scalar_multiply_reduction()) { + TF_RETURN_IF_ERROR(ScalarMultiplyReduction(convolution)); + } + // Zero-sized input or filter. if (ShapeUtil::IsZeroElementArray(convolution->operand(0)->shape()) || ShapeUtil::IsZeroElementArray(convolution->operand(1)->shape())) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 9f29df3c209..9f2a3404116 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -86,6 +86,17 @@ class AlgebraicSimplifierOptions { } bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; } + // Move a constant scalar multiply to the convolution operand or output with + // the smallest tensor size, to reduce the number of scalar multiplies.
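The option described above relies on convolution being linear in each operand: a constant scalar factor can be pushed to either input or to the output without changing the result, so it is cheapest to apply it to the smallest tensor involved. A minimal sketch of that identity (plain C++, not part of the patch; Conv1D is a naive stand-in for the real convolution):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive "valid" 1-D convolution, only to show the algebraic point.
std::vector<double> Conv1D(const std::vector<double>& x,
                           const std::vector<double>& w) {
  std::vector<double> y(x.size() - w.size() + 1, 0.0);
  for (std::size_t i = 0; i < y.size(); ++i) {
    for (std::size_t j = 0; j < w.size(); ++j) y[i] += x[i + j] * w[j];
  }
  return y;
}

int main() {
  const double k = 0.5;  // the constant scalar factor
  const std::vector<double> x = {1, 2, 3, 4, 5, 6};
  const std::vector<double> w = {0.25, -1.0, 0.75};

  // conv(k * x, w): the scalar is applied to the larger input tensor.
  std::vector<double> kx = x;
  for (double& v : kx) v *= k;
  const std::vector<double> lhs = Conv1D(kx, w);

  // k * conv(x, w): the same values, but the scalar touches the smaller
  // output, which is the kind of choice the new option makes.
  std::vector<double> rhs = Conv1D(x, w);
  for (double& v : rhs) v *= k;

  for (std::size_t i = 0; i < lhs.size(); ++i) {
    assert(std::abs(lhs[i] - rhs[i]) < 1e-12);
  }
  return 0;
}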
+ void set_enable_scalar_multiply_reduction( + bool enable_scalar_multiply_reduction) { + enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction; + } + + bool enable_scalar_multiply_reduction() const { + return enable_scalar_multiply_reduction_; + } + // If enable_window_reduce_replacement is true, the kReduceWindow instruction // can be optimized by replacement with simpler operations. void set_enable_window_reduce_to_reduce_replacement( @@ -146,6 +157,7 @@ class AlgebraicSimplifierOptions { bool enable_dot_to_multiply_rewrite_{true}; bool enable_conv_simplification_{true}; bool enable_conv_operand_swap_{true}; + bool enable_scalar_multiply_reduction_{false}; bool enable_window_reduce_to_reduce_replacement_{true}; bool enable_reduce_of_reshape_{true}; bool replace_transpose_with_bitcast_{true}; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 779d6c9cdc5..95700b2a994 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -117,6 +117,52 @@ TEST_F(AlgebraicSimplifierTest, FactorFpAddition) { m::ConstantScalar(0.125)))); } +// (A*C1) * (B*C2) => (A*B)*(C1*C2) +TEST_F(AlgebraicSimplifierTest, MultiplyChain) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + c = f32[] constant(2) + d = f32[] constant(4) + x = f32[] multiply(p0, c) + y = f32[] multiply(p1, d) + ROOT z = f32[] multiply(x, y) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::MultiplyAnyOrder(m::Parameter(0), m::Parameter(1)), + m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4))))); +} + +// MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==> +// MUL(X, BROADCAST(MUL(Y, BROADCAST(constant)))) +TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[2,2] parameter(0) + p1 = f32[] parameter(1) + b = f32[] constant(2) + c = f32[2, 2] broadcast(b), dimensions={} + x = f32[2,2] multiply(p0, c) + y = f32[2,2] broadcast(p1), dimensions={} + ROOT z = f32[2,2] multiply(y, x) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Parameter(0), m::Broadcast(m::MultiplyAnyOrder( + m::Parameter(1), m::Constant()))))); +} + // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { const char* kModuleStr = R"( @@ -1568,6 +1614,32 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(0)))); } +// Test that pow(A, 3) is simplified to A*A*A. 
+TEST_F(AlgebraicSimplifierTest, Pow3) { + auto m = CreateNewVerifiedModule(); + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* three = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, three)); + + auto computation = m->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Power(m::Parameter(0), m::Op().Is(three)))); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(m.get()).ValueOrDie()); + + EXPECT_THAT( + computation->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), + m::Multiply(m::Parameter(0), m::Parameter(0))))); +} + // Test that pow(A, -1) is simplified to 1/A. TEST_F(AlgebraicSimplifierTest, PowNegative1) { auto m = CreateNewVerifiedModule(); @@ -2014,6 +2086,80 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { EXPECT_THAT(computation->root_instruction(), param0); } +TEST_F(AlgebraicSimplifierTest, SliceReverse) { + const char* const hlo_string = R"( +HloModule module + +ENTRY test { + param = f32[6,7,32] parameter(0) + constant = f32[] constant(0) + pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0 + rev = f32[8,7,32] reverse(pad), dimensions={0,2} + slice = f32[1,7,32] slice(rev), slice={[2:3:1], [0:7:1], [0:32:1]} + ROOT tuple = (f32[1,7,32]) tuple(slice) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloComputation* computation = module->entry_computation(); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad()))))); + const HloInstruction* slice = + computation->root_instruction()->operand(0)->operand(0); + EXPECT_TRUE( + ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 7, 32}))); + // slice start,limit of 0th and 2nd dimensions are changed + // while 1st dimension's slice start, limit remains the same since + // it is not reversed. 
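The expected start/limit values checked below follow directly from the remapping formula in TryToReorderSliceAndReverse. A small standalone check (plain C++, not part of the patch; RemapThroughReverse is a hypothetical helper mirroring that arithmetic):

#include <cassert>
#include <cstdint>

// Mirrors the start/limit remapping used when a slice is moved before a
// reverse: a slice [start, limit, stride) of reverse(X) on a reversed
// dimension becomes a slice of X itself.
struct Range { int64_t start, limit; };

Range RemapThroughReverse(int64_t dim_size, int64_t start, int64_t limit,
                          int64_t stride) {
  int64_t find_nth = (limit - start - 1) / stride;  // last element actually taken
  find_nth = start + find_nth * stride;
  limit = find_nth + 1;                             // tighten the limit to it
  return {(dim_size - start) - (limit - start), dim_size - start};
}

int main() {
  // Dimension 0 of the SliceReverse test: size 8, slice [2:3:1] -> [5:6).
  Range r0 = RemapThroughReverse(8, 2, 3, 1);
  assert(r0.start == 5 && r0.limit == 6);
  // Dimension 2 of the non-unit-stride test: size 32, slice [0:32:5] -> [1:32).
  Range r2 = RemapThroughReverse(32, 0, 32, 5);
  assert(r2.start == 1 && r2.limit == 32);
  return 0;
}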
+ EXPECT_EQ(slice->slice_starts(0), 5); + EXPECT_EQ(slice->slice_limits(0), 6); + EXPECT_EQ(slice->slice_starts(1), 0); + EXPECT_EQ(slice->slice_limits(1), 7); + EXPECT_EQ(slice->slice_starts(2), 0); + EXPECT_EQ(slice->slice_limits(2), 32); + EXPECT_EQ(slice->slice_strides(0), 1); + EXPECT_EQ(slice->slice_strides(1), 1); + EXPECT_EQ(slice->slice_strides(2), 1); +} + +TEST_F(AlgebraicSimplifierTest, SliceReverseNonUnitEvenOddStrides) { + const char* const hlo_string = R"( +HloModule module + +ENTRY test { + param = f32[6,7,32] parameter(0) + constant = f32[] constant(0) + pad = f32[8,7,32] pad(param, constant), padding=1_1x0_0x0_0 + rev = f32[8,7,32] reverse(pad), dimensions={0,1,2} + slice = f32[1,2,7] slice(rev), slice={[2:3:2], [0:7:4], [0:32:5]} + ROOT tuple = (f32[1,2,7]) tuple(slice) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + HloComputation* computation = module->entry_computation(); + EXPECT_THAT(computation->root_instruction(), + GmockMatch(m::Tuple(m::Reverse(m::Slice(m::Pad()))))); + const HloInstruction* slice = + computation->root_instruction()->operand(0)->operand(0); + EXPECT_TRUE( + ShapeUtil::Equal(slice->shape(), ShapeUtil::MakeShape(F32, {1, 2, 7}))); + // slice start,limit of all dimensions are changed + EXPECT_EQ(slice->slice_starts(0), 5); + EXPECT_EQ(slice->slice_limits(0), 6); + EXPECT_EQ(slice->slice_starts(1), 2); + EXPECT_EQ(slice->slice_limits(1), 7); + EXPECT_EQ(slice->slice_starts(2), 1); + EXPECT_EQ(slice->slice_limits(2), 32); + EXPECT_EQ(slice->slice_strides(0), 2); + EXPECT_EQ(slice->slice_strides(1), 4); + EXPECT_EQ(slice->slice_strides(2), 5); +} + // Test that empty operands of concatenates are removed. 
TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { auto m = CreateNewVerifiedModule(); @@ -4677,6 +4823,25 @@ TEST_F(AlgebraicSimplifierTest, SliceOfConcatNonScalarInput) { EXPECT_EQ(root->slice_limits(0), 2); } +TEST_F(AlgebraicSimplifierTest, ConcatToBroadcast) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + p = f32[2,1,4] parameter(0) + ROOT concat = f32[2,6,4] concatenate(p,p,p,p,p,p), dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); +} + TEST_F(AlgebraicSimplifierTest, NegateNegate) { const char* hlo_string = R"( HloModule module @@ -5197,6 +5362,59 @@ ENTRY AddBroadcastZeroWithDynamicSlice { EXPECT_THAT(root->operand(1)->opcode(), HloOpcode::kPad); } +TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReduction) { + const char* hlo_string = R"( +HloModule ConstScalarMultiply +ENTRY ConstScalarMultiply { + param0 = f32[16,512,4096]{2,1,0} parameter(0) + constant.0 = f32[] constant(0.5) + broadcast.0 = f32[16,512,4096] broadcast(constant.0), dimensions={} + multiply.0 = f32[16,512,4096]{2,1,0} multiply(param0, broadcast.0) + param1 = f32[16,512,4096]{2,1,0} parameter(1) + multiply.1 = f32[16,512,4096]{2,1,0} multiply(multiply.0, param1) + param2 = f32[16,512,1024]{2,1,0} parameter(2) + constant.1 = f32[] constant(1.109) + broadcast.1 = f32[16,512,1024] broadcast(constant.1), dimensions={} + multiply.2 = f32[16,512,1024]{2,1,0} multiply(param2, broadcast.1) + ROOT convolution = f32[4096,1024,1]{1,0,2} convolution(multiply.1, multiply.2), window={size=16}, dim_labels=0fb_0io->bf0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_scalar_multiply_reduction(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMultiply); + EXPECT_THAT(root, + GmockMatch(m::MultiplyAnyOrder( + m::Op(), m::Broadcast(m::ConstantScalar(0.5f * 1.109f))))); +} + +TEST_F(AlgebraicSimplifierTest, ScalarMultiplyReductionMultiUser) { + const char* hlo_string = R"( +HloModule ConstScalarMultiply +ENTRY ConstScalarMultiply { + param0 = f32[16,512,1024] parameter(0) + param1 = f32[4096,1024,1] parameter(1) + convolution = f32[16,512,4096] convolution(param0, param1), window={size=1}, dim_labels=0bf_oi0->0bf + constant.1 = f32[] constant(0.5) + broadcast.1 = f32[16,512,4096] broadcast(constant.1), dimensions={} + multiply.1 = f32[16,512,4096] multiply(convolution, broadcast.1) + param2 = f32[16,512,4096] parameter(2) + multiply.2 = f32[16,512,4096] multiply(convolution, param2) + ROOT add.1 = f32[16,512,4096] add(multiply.1, multiply.2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_scalar_multiply_reduction(true); + AlgebraicSimplifier simplifier(options); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); +} + INSTANTIATE_TEST_SUITE_P(DotOfConcatSimplificationTestInstantiation, DotOfConcatSimplificationTest, ::testing::ValuesIn(kDotOfConcatTestSpecs)); @@ -6145,10 +6363,10 @@ 
TEST_F(AlgebraicSimplifierTest, ReduceOfBatchDotToContractingDimension) { } test { p0 = f32[32,8,5,6] parameter(0) - p1 = f32[32,8,6,7] parameter(1) + p1 = f32[8,32,6,7] parameter(1) d = f32[32,8,5,7] dot(p0, p1), lhs_batch_dims={0,1}, - rhs_batch_dims={0,1}, + rhs_batch_dims={1,0}, rhs_contracting_dims={2}, lhs_contracting_dims={3} c = f32[] constant(0) diff --git a/tensorflow/compiler/xla/service/comparison_expander.cc b/tensorflow/compiler/xla/service/comparison_expander.cc new file mode 100644 index 00000000000..5c88ff8cae2 --- /dev/null +++ b/tensorflow/compiler/xla/service/comparison_expander.cc @@ -0,0 +1,133 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/comparison_expander.h" + +#include "tensorflow/compiler/xla/client/lib/comparators.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { + +HloInstruction* BitcastConvertFloatingPointToIntegral( + HloComputation* computation, HloInstruction* value, + const Shape& signed_shape, const Shape& unsigned_shape, + HloInstruction* zero, HloInstruction* max_value) { + // Switch from a floating point value to a integer value in such a way that + // when using the integer value to compare, we get the same result for normal + // values, and -Nan is treated as the smallest value, and Nan is treated as + // the largest value. + // If f is a float, and + // x = bit_cast(f); + // y = x < 0 ? numeric_limits::max() - x : x; + // then y is ordered as an int32 such that finite values have the obvious + // order, -0 is ordered before 0, and -NaN and NaN appear at the beginning + // and end of the ordering. + // Note that in order to avoid -x to overflow, we calculate + // numeric_limits::max() - x as unsigned, and then convert back to + // signed. 
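The key computation sketched in the comment above can be exercised in isolation. A standalone sketch (plain C++, not part of the patch), with memcpy standing in for the HLO bitcast-converts:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Same mapping as the HLO built below: bitcast the float to int32, and for
// negative bit patterns replace the key with int32_max - unsigned_bits
// (mod 2^32). Signed comparison of the keys then yields the total order.
int32_t TotalOrderKey(float f) {
  int32_t s;
  uint32_t u;
  std::memcpy(&s, &f, sizeof f);
  std::memcpy(&u, &f, sizeof f);
  if (s >= 0) return s;
  uint32_t flipped =
      static_cast<uint32_t>(std::numeric_limits<int32_t>::max()) - u;
  std::memcpy(&s, &flipped, sizeof flipped);
  return s;
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float neg_nan = std::copysign(nan, -1.0f);
  // -NaN < -inf < -1 < -0 < +0 < 1 < +inf < +NaN under the key ordering.
  assert(TotalOrderKey(neg_nan) < TotalOrderKey(-inf));
  assert(TotalOrderKey(-inf) < TotalOrderKey(-1.0f));
  assert(TotalOrderKey(-1.0f) < TotalOrderKey(-0.0f));
  assert(TotalOrderKey(-0.0f) < TotalOrderKey(0.0f));
  assert(TotalOrderKey(0.0f) < TotalOrderKey(1.0f));
  assert(TotalOrderKey(1.0f) < TotalOrderKey(inf));
  assert(TotalOrderKey(inf) < TotalOrderKey(nan));
  return 0;
}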
+ auto signed_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(signed_shape, value)); + auto unsigned_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(unsigned_shape, value)); + auto flipped_value = computation->AddInstruction(HloInstruction::CreateBinary( + unsigned_shape, HloOpcode::kSubtract, max_value, unsigned_value)); + flipped_value = computation->AddInstruction( + HloInstruction::CreateBitcastConvert(signed_shape, flipped_value)); + auto compare_shape = signed_shape; + compare_shape.set_element_type(PRED); + auto is_negative = computation->AddInstruction(HloInstruction::CreateCompare( + compare_shape, signed_value, zero, ComparisonDirection::kLt)); + return computation->AddInstruction( + HloInstruction::CreateTernary(signed_shape, HloOpcode::kSelect, + is_negative, flipped_value, signed_value)); +} + +bool ComparisonExpander::InstructionMatchesPattern( + HloInstruction* instruction) { + if (HloCompareInstruction* compare = + dynamic_cast(instruction)) { + HloInstruction* lhs = instruction->operands()[0]; + if (compare->type() == Comparison::Type::kFloatTotalOrder && + primitive_util::IsFloatingPointType(lhs->shape().element_type())) { + return true; + } + } + return false; +} + +StatusOr ComparisonExpander::ExpandInstruction( + HloInstruction* instruction) { + CHECK(instruction->opcode() == HloOpcode::kCompare); + HloCompareInstruction* compare = + static_cast(instruction); + CHECK(compare->type() == Comparison::Type::kFloatTotalOrder); + HloComputation* computation = instruction->parent(); + HloInstruction* lhs = instruction->operands()[0]; + HloInstruction* rhs = instruction->operands()[1]; + Shape compare_shape = lhs->shape(); + PrimitiveType compare_type = compare_shape.element_type(); + CHECK(primitive_util::IsFloatingPointType(compare_type)); + // Special-case handling for BF16. We currently do not support direct + // comparisons with BF16, so we convert to F32 and then use the F32 + // comparison logic. 
+ if (compare_type == BF16) { + compare_type = F32; + compare_shape.set_element_type(compare_type); + lhs = computation->AddInstruction( + HloInstruction::CreateConvert(compare_shape, lhs)); + rhs = computation->AddInstruction( + HloInstruction::CreateConvert(compare_shape, rhs)); + } + + int64 bit_width = primitive_util::BitWidth(compare_type); + PrimitiveType signed_type = + primitive_util::SignedIntegralTypeForBitWidth(bit_width); + PrimitiveType unsigned_type = + primitive_util::UnsignedIntegralTypeForBitWidth(bit_width); + auto signed_shape = compare_shape; + signed_shape.set_element_type(signed_type); + auto unsigned_shape = compare_shape; + unsigned_shape.set_element_type(unsigned_type); + auto zero_value = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(signed_type))); + zero_value = computation->AddInstruction(HloInstruction::CreateBroadcast( + signed_shape, zero_value, zero_value->shape().dimensions())); + auto max_signed = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::MaxValue(signed_type))); + auto max_shape = max_signed->shape(); + max_shape.set_element_type(unsigned_type); + auto max_unsigned = computation->AddInstruction( + HloInstruction::CreateConvert(max_shape, max_signed)); + auto max_value = computation->AddInstruction(HloInstruction::CreateBroadcast( + unsigned_shape, max_unsigned, max_shape.dimensions())); + lhs = BitcastConvertFloatingPointToIntegral( + computation, lhs, signed_shape, unsigned_shape, zero_value, max_value); + rhs = BitcastConvertFloatingPointToIntegral( + computation, rhs, signed_shape, unsigned_shape, zero_value, max_value); + auto new_compare = computation->AddInstruction(HloInstruction::CreateCompare( + instruction->shape(), lhs, rhs, compare->direction(), + Comparison::Type::kSigned)); + VLOG(2) << "New comparison instruction for total order:" + << new_compare->ToString() << "\n"; + return new_compare; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/comparison_expander.h b/tensorflow/compiler/xla/service/comparison_expander.h new file mode 100644 index 00000000000..df8b5dc0137 --- /dev/null +++ b/tensorflow/compiler/xla/service/comparison_expander.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +// A pass which performs expansion of the comparison operator to support total +// order comparison of floating point numbers. 
+class ComparisonExpander : public OpExpanderPass { + public: + explicit ComparisonExpander() = default; + ~ComparisonExpander() override = default; + absl::string_view name() const override { return "comparison-expander"; } + + private: + // Returns `true` if `instruction` should be expanded by this pass. + bool InstructionMatchesPattern(HloInstruction* instruction) override; + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPARISON_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index ce9c8a4ea62..f8e4f591a5d 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -92,6 +92,7 @@ CompileOnlyService::CompileAheadOfTime( execution_options.mutable_device_assignment())); } execution_options.set_use_spmd_partitioning(options.use_spmd_partitioning()); + execution_options.set_deduplicate_hlo(options.deduplicate_hlo()); for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_host_program_shape()); *execution_options.mutable_shape_with_output_layout() = diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 57b24e372e6..312a068ba65 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -77,6 +77,7 @@ class AotCompilationOptions { virtual int64 replica_count() const { return 0; } virtual int64 num_cores() const { return 0; } virtual bool use_spmd_partitioning() const { return false; } + virtual bool deduplicate_hlo() const { return false; } // Optional allocator that may be used for allocating temp space on the device // during compilation. diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer.cc b/tensorflow/compiler/xla/service/conditional_canonicalizer.cc new file mode 100644 index 00000000000..8af8e11febd --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer.cc @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { +namespace { +Status CanonicalizeNonTupleConditional(HloInstruction* conditional) { + TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); + for (auto* branch : conditional->called_computations()) { + HloInstruction* root = branch->root_instruction(); + TF_RET_CHECK(!root->shape().IsTuple()); + + HloInstruction* tuple = + branch->AddInstruction(HloInstruction::CreateTuple({root})); + branch->set_root_instruction(tuple, /*accept_different_shape=*/true); + } + auto parent = conditional->parent(); + auto root_shape = conditional->shape(); + auto new_shape = ShapeUtil::MakeTupleShape({root_shape}); + auto new_conditional = + parent->AddInstruction(conditional->CloneWithNewShape(new_shape)); + auto gte = parent->AddInstruction( + HloInstruction::CreateGetTupleElement(root_shape, new_conditional, 0)); + TF_RETURN_IF_ERROR(parent->ReplaceInstruction(conditional, gte)); + return Status::OK(); +} +} // namespace + +StatusOr ConditionalCanonicalizer::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "ConditionalCanonicalizer::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + for (auto* inst : comp->MakeInstructionPostOrder()) { + if (inst->opcode() == HloOpcode::kConditional && + !inst->shape().IsTuple()) { + TF_RETURN_IF_ERROR(CanonicalizeNonTupleConditional(inst)); + changed = true; + } + } + } + XLA_VLOG_LINES( + 2, "ConditionalCanonicalizer::Run(), after:\n" + module->ToString()); + return changed; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer.h b/tensorflow/compiler/xla/service/conditional_canonicalizer.h new file mode 100644 index 00000000000..a390d87a007 --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer.h @@ -0,0 +1,38 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Canonicalize output of conditionals, make non-tuple outputs into tuple with +// single element output. After this pass, all conditional instructions have +// tuple outputs. 
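As a rough analogy for what this pass does, outside of HLO entirely: every branch result is wrapped into a one-element tuple and consumers read it back through a get-tuple-element, so downstream passes only ever see tuple-shaped conditionals. A toy sketch (plain C++, purely illustrative):

#include <cassert>
#include <tuple>

// Before canonicalization: each "branch" returns a bare value.
int TrueBranch() { return 0; }
int FalseBranch() { return 1; }

// After canonicalization: every branch returns a one-element tuple, and the
// caller reads the result through a get-tuple-element (std::get<0> here).
std::tuple<int> TrueBranchCanonical() { return std::make_tuple(TrueBranch()); }
std::tuple<int> FalseBranchCanonical() {
  return std::make_tuple(FalseBranch());
}

int main() {
  const bool pred = false;
  std::tuple<int> result =
      pred ? TrueBranchCanonical() : FalseBranchCanonical();
  // Equivalent to the get-tuple-element inserted after the new conditional.
  assert(std::get<0>(result) == (pred ? TrueBranch() : FalseBranch()));
  return 0;
}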
+class ConditionalCanonicalizer : public HloModulePass { + public: + absl::string_view name() const override { + return "conditional canonicalizer"; + } + + StatusOr Run(HloModule* module) override; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ diff --git a/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc b/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc new file mode 100644 index 00000000000..498260cbabf --- /dev/null +++ b/tensorflow/compiler/xla/service/conditional_canonicalizer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ConditionalCanonicalizerTest : public HloTestBase { + protected: + ConditionalCanonicalizerTest() {} +}; + +TEST_F(ConditionalCanonicalizerTest, DenseArrayConditionalRewrite) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +true_branch { + true_param = (s32[3,2]) parameter(0) + ROOT root = s32[] constant(0) +} + +false_branch { + false_param = (s32[3,2]) parameter(0) + ROOT root = s32[] constant(1) +} + +ENTRY entry { + param0 = s32[3,2] parameter(0) + branch = pred[] constant(false) + param_tuple = (s32[3 ,2]) tuple(param0) + ROOT conditional = s32[] conditional(branch, param_tuple, param_tuple), + true_computation=true_branch, false_computation=false_branch +} +)") + .ValueOrDie(); + ConditionalCanonicalizer pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::GetTupleElement(op::Conditional())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index 6db4c3eb6d4..cdda0aeb925 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include #include -#include #include #include @@ -46,161 +45,81 @@ limitations under the License. namespace xla { -namespace { +namespace conditional_opt { -struct ConditionalBoundary { - ConditionalBoundary(HloInstruction* op, int64 op_index, HloInstruction* usr) - : operand(op), operand_index(op_index), user(usr) {} - // `operand` is one of `user`'s operand. - - // Instruction that remains in the conditional but one of its user - // is moved out of conditonal. - HloInstruction* operand; - // operand_index for `operand` in the `user`. - int64 operand_index; - // Instruction that moved out of conditional. - HloInstruction* user; -}; - -// Visit the root instructions to its operands follow BFS. -// Will visit an instructions after all its users have been visited. Parameters -// are not visited. -class BranchVisitor { +class BoundaryVisitor { public: - explicit BranchVisitor(const HloComputation* branch_computation) { - HloInstruction* root_inst = branch_computation->root_instruction(); - worklist_.push_back(root_inst); - visited_.insert(root_inst); - for (auto parameter_inst : branch_computation->parameter_instructions()) { - parameter_instructions_.insert(parameter_inst); - } + // start with an existing conditional computation. + explicit BoundaryVisitor(HloInstruction* conditional) { + Boundary b(Boundary::Position::kInsideBranch); + b.mutable_operands().push_back(conditional); + worklist_.push_back(b); } - // Get next intruction to visit. - HloInstruction* GetNextInstruction() { - if (!worklist_.empty()) { - HloInstruction* inst = worklist_.front(); + // Start with an empty work list. + BoundaryVisitor() {} + // Get next boundary to visit. + Boundary PopNextBoundary() { + CHECK(!worklist_.empty()); + Boundary b = worklist_.front(); + worklist_.pop_front(); + // if b is already visited, it must have multiple users and is already in + // new boundaries. Skip it. Only checking the first operand of b because b + // is expected to have at least one operand, and all the operands in b + // must be identical instructions from different branches for b to be moved. + while (!worklist_.empty() && ContainsKey(visited_, b.operands()[0])) { + b = worklist_.front(); worklist_.pop_front(); - return inst; } - return nullptr; + visited_.insert(b.operands()[0]); + return b; + } + void AddToWorkList(const Boundary& b) { + CHECK(!b.operands().empty()); + worklist_.push_back(b); } - // Add operands of one instruction to worklist for further visit. - void AddInstructionOperands(HloInstruction* inst) { - int64 operand_count = inst->operand_count(); - for (int i = 0; i < operand_count; i++) { - HloInstruction* operand = inst->mutable_operand(i); - if (ContainsKey(visited_, operand)) { - continue; + bool HasNextBoundary() { + while (!worklist_.empty()) { + Boundary b = worklist_.front(); + if (!ContainsKey(visited_, b.operands()[0])) { + break; } - bool all_user_visited = std::all_of( - operand->users().begin(), operand->users().end(), - [&](HloInstruction* user) { return ContainsKey(visited_, user); }); - - if (!all_user_visited) { - continue; - } - // Do not visit parameter_instructions. - if (ContainsKey(parameter_instructions_, operand)) { - // Add the operand and this instruction to the boundaries. - boundaries_.emplace_back(operand, i, inst); - continue; - } - worklist_.push_back(operand); - visited_.insert(operand); + worklist_.pop_front(); } - } - - // Add instruction and its users to conditional boundaries. 
- void AddInstructionToBoundary(HloInstruction* inst) { - for (auto user : inst->users()) { - boundaries_.emplace_back(inst, user->operand_index(inst), user); - } - } - - // Add instruction to the to be removed instructions set and vector. - void AddInstructionToHoist(HloInstruction* inst) { - instructions_to_hoist_set_.insert(inst); - instructions_to_hoist_.emplace_back(inst); - } - - // If visitor has next instruction to visit. - bool HasNextInstruction() const { return !worklist_.empty(); } - - // If there is no hoist intruction. - int64 HoistInstructionSize() { return instructions_to_hoist_.size(); } - - // Get boundaries of this branch. - const std::vector& boundaries() const { - return boundaries_; - } - - // Get instructions to hoist in this branch. - const std::vector& instructions_to_hoist() const { - return instructions_to_hoist_; - } - - // Get hoist instruction set in this branch. - const std::unordered_set& instructions_to_hoist_set() const { - return instructions_to_hoist_set_; + return !worklist_.empty(); } private: // worklist is the deque that contains instructions to be visited. - std::deque worklist_; - - // instructions that has been visited. - std::unordered_set visited_; - - // parameter instructions of the branch. - std::unordered_set parameter_instructions_; - - // Boundaries contains the set of instructions that its operand is within - // conditional but it can be hoist out of conditional. - std::vector boundaries_; - - // Instructions to hoist. - std::unordered_set instructions_to_hoist_set_; - - // Instructions to hoist, the order within this vector is BFS and - // an instruction's order will always be after its users. - std::vector instructions_to_hoist_; + std::deque worklist_; + absl::flat_hash_set visited_; }; -// Returns true if `instruction` is worth hoisting out. -bool WorthHoisting(HloInstruction* instruction) { - for (const auto* operand : instruction->operands()) { - // Only move out instructions that won't share the same operand - // to avoid copy of the operand. - if (operand->user_count() > 1) { - return false; - } - } - switch (instruction->opcode()) { - case HloOpcode::kConvert: - // If Convert is after AllReduce, it is worth moving out AllReduce out - // of conditional for AR/CRS combine. If Convert is after other ops such - // as Dot or Convolutional, it is better to keep convert within - // conditional so that convert can be fused with Dot or Convolutional. - // - // TODO(b/154283721): figure out the scenario when convert can be fused - // with AllReduce out of conditional. - if (instruction->operand(0)->opcode() == HloOpcode::kAllReduce) { - return true; - } - return false; - case HloOpcode::kAllReduce: - case HloOpcode::kAdd: - case HloOpcode::kConstant: - case HloOpcode::kSubtract: - case HloOpcode::kMultiply: - case HloOpcode::kDivide: - case HloOpcode::kTuple: - case HloOpcode::kSqrt: +// Returns estimation of potential reuses carried by a given pair of +// instructions. 
Use different integers to classify different levels +// of reuses This is used as a placeholder only, assuming all +// instructions can be fused to enable data reuses +int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { + VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: " + << op->ToString() << "=>" << user->ToString() << "\n"; + switch (user->opcode()) { case HloOpcode::kGetTupleElement: - return true; + case HloOpcode::kTuple: + return 0; default: - return false; + break; + } + switch (op->opcode()) { + // These instructions are lightweight and easy to fuse. + case HloOpcode::kConstant: + case HloOpcode::kGetTupleElement: + return 0; + default: + // Assume fusion will not happen anyway if user count > 1) + if (op->user_count() > 1) { + return 0; + } + return 10; } } @@ -220,7 +139,7 @@ bool InstructionWithinBranchIdentical( return *a == *b; }; - if (instructions[0] == nullptr) { + if (instructions.empty()) { return false; } @@ -248,109 +167,42 @@ bool InstructionWithinBranchIdentical( }); } -// Returns if all the visitors/branches has next instruction to visit. -bool HasNextInstruction(const std::vector& visitors) { - bool has_next = true; - for (const auto& visitor : visitors) { - has_next &= visitor.HasNextInstruction(); - } - return has_next; -} - -// Create tuple element as the new root of the branch. The tuple will contain -// the operands that can't move out of conditional but its user will be moved -// out of conditional. -HloInstruction* CreateNewRoot( - const std::vector& boundaries, - const std::unordered_set& instructions_to_hoist_set, - HloComputation* computation) { - std::vector elements; - elements.reserve(boundaries.size()); - for (auto boundary : boundaries) { - if (ContainsKey(instructions_to_hoist_set, boundary.user)) { - elements.push_back(boundary.operand); +// Copy the ith instruction in boundary to outside of conditional, or do the +// opposite (for moving in). +Status CopyInOrOutOfConditional( + Boundary& boundary, int64 dest_index, HloComputation* parent, + absl::flat_hash_map& hoisted_instructions) { + CHECK(dest_index == 0 || boundary.IsOutsideBranch()); + HloInstruction* op = boundary.operands()[0]; + absl::InlinedVector new_operands; + for (int i = 0; i < op->operands().size(); ++i) { + auto op_i = op->operands()[i]; + VLOG(2) << "Looking for operand:" << op_i->ToString() << "\n"; + if (ContainsKey(hoisted_instructions, op_i)) { + auto new_op_i = + FindOrDie(hoisted_instructions, op_i).operands()[dest_index]; + VLOG(2) << "new operand:" << new_op_i->ToString() << "\n"; + new_operands.push_back(new_op_i); + } else { + CHECK(op_i->opcode() == HloOpcode::kConstant); + auto new_op_i = parent->AddInstruction(op_i->Clone()); + VLOG(2) << "new operand:" << new_op_i->ToString() << "\n"; + new_operands.push_back(new_op_i); } } - return computation->AddInstruction(HloInstruction::CreateTuple(elements)); -} - -// Copy identical instructions within conditional outside of conditional. -void CopyIdenticalInstructionsOutOfConditional( - const std::vector& instructions_to_hoist, - HloComputation* conditional_parent, - absl::flat_hash_map* - hoisted_instructions) { - int64 instructions_size = instructions_to_hoist.size(); - // Visit the operands before its users and copy it, so that the copied - // user will point to the correct operand. 
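The reuse estimate introduced above is intentionally coarse. A simplified stand-in (plain C++, not part of the patch; string opcodes replace HloOpcode) that mirrors its scoring:

#include <cstdint>
#include <cstdio>
#include <string>

// Simplified stand-in for the heuristic: cheap producers (constants,
// get-tuple-element) and cheap consumers (tuple, get-tuple-element) carry no
// reuse; anything else with a single consumer is assumed fusible and scores a
// fixed weight of 10.
struct Instr {
  std::string opcode;
  int user_count;
};

int64_t ToyReusesCarriedBy(const Instr& op, const Instr& user) {
  if (user.opcode == "get-tuple-element" || user.opcode == "tuple") return 0;
  if (op.opcode == "constant" || op.opcode == "get-tuple-element") return 0;
  if (op.user_count > 1) return 0;  // assume no fusion across a fan-out
  return 10;
}

int main() {
  Instr convert{"convert", 1}, add{"add", 1}, root{"tuple", 1}, c{"constant", 2};
  // A convert feeding a single add is worth keeping next to it...
  std::printf("convert->add: %lld\n", (long long)ToyReusesCarriedBy(convert, add));
  // ...while a constant, or anything feeding only the root tuple, scores zero.
  std::printf("constant->add: %lld\n", (long long)ToyReusesCarriedBy(c, add));
  std::printf("add->tuple: %lld\n", (long long)ToyReusesCarriedBy(add, root));
  return 0;
}

The scores only matter relative to each other; BenefitForMovingBoundaries sums them on both sides of a boundary and compares the two sums to decide the direction of motion.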
- for (int64 i = instructions_size - 1; i >= 0; i--) { - HloInstruction* old_instruction = instructions_to_hoist[i]; - auto get_new_operand = [&](HloInstruction* old_operand) { - // If the operand can't be found in `instructions_to_hoist`, this - // operand will be in the `boundaries`, GetTupleElement instructions - // will be added later to replace this operand. - if (!ContainsKey(*hoisted_instructions, old_operand)) { - return old_operand; - } - return FindOrDie(*hoisted_instructions, old_operand); - }; - - absl::InlinedVector new_operands; - absl::c_transform(old_instruction->operands(), - std::back_inserter(new_operands), get_new_operand); - - HloInstruction* new_instruction = conditional_parent->AddInstruction( - old_instruction->CloneWithNewOperands(old_instruction->shape(), - new_operands)); - // Maps the instruction outside of conditional to the instruction - // inside of the conditional. - InsertOrDie(hoisted_instructions, old_instruction, new_instruction); - } -} - -// If there are instructions to hoist, the root of the conditional must be -// moved out. Change the users of the conditional to the hoisted instruction -// of the new root. -Status ChangeConditionalUsers( - HloInstruction* conditional, HloInstruction* old_root, - const absl::flat_hash_map& - hoisted_instructions) { - HloInstruction* new_root = FindOrDie(hoisted_instructions, old_root); - TF_RETURN_IF_ERROR(conditional->ReplaceAllUsesWith(new_root)); - return Status::OK(); -} - -// Insert GetTupleElement before the instructions whose operands might still -// be within the conditional. -Status CreateGetTupleElementAfterConditional( - const std::vector& boundaries, - const std::unordered_set& instructions_to_hoist_set, - const absl::flat_hash_map& - hoisted_instructions, - HloInstruction* conditional, HloComputation* computation) { - int boundary_instruction_size = boundaries.size(); - - // Inserts GetTupleElement before the boundary instructions. - for (int i = 0; i < boundary_instruction_size; i++) { - HloInstruction* gte = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - boundaries[i].operand->shape(), conditional, i)); - - HloInstruction* new_instruction = - FindOrDie(hoisted_instructions, boundaries[i].user); - TF_RETURN_IF_ERROR( - new_instruction->ReplaceOperandWith(boundaries[i].operand_index, gte)); - } - return Status::OK(); -} - -// Remove instructions to be hoisted out of the branch computation. -Status RemoveInstructionFromComputation( - const std::vector& instructions_to_hoist, - HloComputation* branch) { - // Will visit the instructions after its users. - for (auto* instruction : instructions_to_hoist) { - TF_RETURN_IF_ERROR(branch->RemoveInstruction(instruction)); + HloInstruction* new_instruction = parent->AddInstruction( + op->CloneWithNewOperands(op->shape(), new_operands)); + VLOG(2) << "new instruction:" << new_instruction->ToString() << "\n"; + // Maps the instruction outside of conditional to the instruction + // inside of the conditional. + for (HloInstruction* op : boundary.operands()) { + Boundary b2 = ContainsKey(hoisted_instructions, op) + ? hoisted_instructions[op] + : Boundary(boundary.IsOutsideBranch() + ? 
Boundary::Position::kInsideBranch + : Boundary::Position::kOutsideBranch); + b2.mutable_operands().push_back(new_instruction); + hoisted_instructions[op] = b2; } return Status::OK(); } @@ -482,7 +334,7 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, old_root = conditional->branch_computation(branch)->root_instruction(); absl::flat_hash_map map_inst_to_tuple_index; std::vector new_operands(old_root->operand_count()); - std::unordered_set to_hoist_set; + absl::flat_hash_set to_hoist_set; for (int64 operand_num = 0; operand_num < old_root->operand_count(); ++operand_num) { @@ -574,128 +426,545 @@ StatusOr ConvertSpecialMove(HloInstruction* conditional, // are the shape of the operands are identical and their properties are // identical. Will start from the root instruction of each branch and get // the identical ops to hoist. -StatusOr MergeIdenticalElements(HloInstruction* conditional, - bool is_layout_sensitive) { - VLOG(1) << " visiting conditional:" << conditional->ToString(); - int branch_count = conditional->branch_count(); - if (branch_count <= 0) { +StatusOr ConditionalCodeMotion::MoveInstructionOut( + HloInstruction* conditional, std::vector& to_move_out, + std::vector& new_boundaries) { + if (to_move_out.empty()) { return false; } - - std::vector visitors; - visitors.reserve(branch_count); - // Visit instructions from the root instruction to the operands using BFS. - for (int i = 0; i < branch_count; i++) { - visitors.emplace_back(BranchVisitor(conditional->branch_computation(i))); - } - - // The instructions to be visited within each branch. - std::vector front_instructions(branch_count); - - while (HasNextInstruction(visitors)) { - for (int i = 0; i < branch_count; i++) { - front_instructions[i] = visitors[i].GetNextInstruction(); - } - // If two instructions has the same shape, opcode and its operands has the - // same shape, then this instruction can be moved out of conditional. - if (WorthHoisting(front_instructions[0]) && - InstructionWithinBranchIdentical(front_instructions, - is_layout_sensitive)) { - for (int i = 0; i < branch_count; i++) { - visitors[i].AddInstructionOperands(front_instructions[i]); - visitors[i].AddInstructionToHoist(front_instructions[i]); - } - } else { - for (int i = 0; i < branch_count; i++) { - // If the ops are not identical, these ops and its users will - // be in the boundaries` of the conditional. These ops will be stayed - // within the conditional, but one its only user will be moved out - // of conditional. - visitors[i].AddInstructionToBoundary(front_instructions[i]); - } - } - } - - if (visitors[0].HoistInstructionSize() < 1) { - return false; - } - - HloInstruction* old_root = - conditional->branch_computation(0)->root_instruction(); + VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n"; HloComputation* conditional_parent = conditional->parent(); + // save the old users before add new conditional user instructions + std::vector old_conditional_users = conditional->users(); // Maps instructions in the conditional body to instructions hoisted outside // the conditional that compute the same value. - absl::flat_hash_map hoisted_instructions; - // Copy identical instructions out of the conditional. - CopyIdenticalInstructionsOutOfConditional(visitors[0].instructions_to_hoist(), - conditional_parent, - &hoisted_instructions); - // If there are instructions to hoist, the root of the conditional must be - // moved out. Change the users of the conditional to the hoisted instruction - // of the new root. 
- TF_RETURN_IF_ERROR( - ChangeConditionalUsers(conditional, old_root, hoisted_instructions)); - - // Create tuple element within each branch and set it as root. - for (int i = 0; i < branch_count; i++) { - HloInstruction* tuple = CreateNewRoot( - visitors[i].boundaries(), visitors[i].instructions_to_hoist_set(), - conditional->branch_computation(i)); - conditional->branch_computation(i)->set_root_instruction(tuple, true); - } - // Changes conditional instruction shape to the shape of the new root. - *conditional->mutable_shape() = - conditional->branch_computation(0)->root_instruction()->shape(); - + absl::flat_hash_map hoisted_instructions; // Insert GetTupleElement before the instructions whose operands might still // be within the conditional. - TF_RETURN_IF_ERROR(CreateGetTupleElementAfterConditional( - visitors[0].boundaries(), visitors[0].instructions_to_hoist_set(), - hoisted_instructions, conditional, conditional_parent)); - - // Remove hoist instructions from the branches. - for (int i = 0; i < branch_count; i++) { - TF_RETURN_IF_ERROR( - RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), - conditional->branch_computation(i))); + VLOG(2) << "before opt:" + << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + << "\n"; + int64 op_index = 0; + for (Boundary b : new_boundaries) { + HloInstruction* op = b.operands()[0]; + CHECK(op != nullptr); + VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; + HloInstruction* gtr = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement(op->shape(), conditional, + op_index++)); + Boundary b2(Boundary::Position::kOutsideBranch); + b2.mutable_operands().push_back(gtr); + hoisted_instructions[op] = b2; } + // Copy boundary instructions out of the conditional. + // Visit the operands before its users and copy it, so that the copied + // user will point to the correct operand. + for (int64 i = to_move_out.size() - 1; i >= 0; i--) { + TF_RETURN_IF_ERROR(CopyInOrOutOfConditional( + to_move_out[i], 0, conditional_parent, hoisted_instructions)); + } + VLOG(2) << "Done copy branch instructions out\n" + << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + << "\n"; + // Change original users of the conditional to use the correct operands. + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + for (auto user_instr : old_conditional_users) { + CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement); + auto tuple_opd = static_cast(user_instr); + int64 index = tuple_opd->tuple_index(); + HloInstruction* old_opd = old_root->operands()[index]; + HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0]; + CHECK(old_opd != nullptr); + CHECK(new_opd != nullptr); + TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); + TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); + } + // Create tuple element within each branch and set it as root. 
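What MoveInstructionOut achieves can be pictured without HLO: each branch stops computing the hoisted expression, its root tuple carries the operands instead, and the expression is evaluated once on the conditional's result. A toy sketch (plain C++ lambdas, purely illustrative):

#include <cassert>
#include <cmath>

int main() {
  const bool pred = true;
  const double x = 9.0, y = 16.0;

  // Before: both branches compute sqrt(...) + 1.0 themselves.
  auto true_branch = [&] { return std::sqrt(x) + 1.0; };
  auto false_branch = [&] { return std::sqrt(y) + 1.0; };
  const double before = pred ? true_branch() : false_branch();

  // After hoisting: the branches only hand back their operand (the root
  // tuple shrinks to the operands of the moved instruction), and
  // sqrt(...) + 1.0 is emitted once, outside, on a get-tuple-element of the
  // conditional.
  auto true_branch_out = [&] { return x; };
  auto false_branch_out = [&] { return y; };
  const double selected = pred ? true_branch_out() : false_branch_out();
  const double after = std::sqrt(selected) + 1.0;

  assert(before == after);
  return 0;
}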
+ int64 branch_count = conditional->branch_count(); + for (int i = 0; i < branch_count; i++) { + auto computation = conditional->branch_computation(i); + std::vector elements; + for (auto b1 : new_boundaries) { + HloInstruction* op = b1.operands()[i]; + VLOG(1) << "branch count=" << i << "\n"; + CHECK(op != nullptr); + VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n"; + elements.push_back(op); + } + HloInstruction* tuple = + computation->AddInstruction(HloInstruction::CreateTuple(elements)); + computation->set_root_instruction(tuple, true); + VLOG(2) << "computation is :" << computation->ToString() << "\n"; + // Remove hoisted instructions from the branches. + for (auto b2 : to_move_out) { + VLOG(2) << "Removing boundary:" << b2.ToString() << "\n"; + TF_RETURN_IF_ERROR(computation->RemoveInstruction(b2.operands()[i])); + } + } + // Change conditional instruction shape to the shape of the new root. + HloInstruction* new_root = + conditional->branch_computation(0)->root_instruction(); + *conditional->mutable_shape() = new_root->shape(); + // + VLOG(2) << "done moving instructions out of branches\n" + << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + << "\n"; return true; } -} // namespace +// Hoist ops from outside of the conditional to inside the branches. +StatusOr ConditionalCodeMotion::MoveInstructionIn( + HloInstruction* conditional, std::vector& to_move_in, + std::vector& new_boundaries) { + if (to_move_in.empty()) { + return false; + } + VLOG(1) << "number of boundaries to move in:" << to_move_in.size() << "\n"; + HloComputation* conditional_parent = conditional->parent(); + VLOG(2) << "before opt:" + << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + << "\n"; + // Mapping instructions to be moved to their new representations. + absl::flat_hash_map hoisted_instructions; + int64 to_move_in_size = to_move_in.size(); + int64 branch_count = conditional->branch_count(); + int64 op_index = conditional->shape().tuple_shapes_size(); + // Map conditional to its old root, then create a new root instruction in each + // branch. 
+ Boundary b(Boundary::Position::kInsideBranch); + for (int i = 0; i < branch_count; i++) { + auto computation = conditional->branch_computation(i); + auto old_root = computation->root_instruction(); + b.mutable_operands().push_back(old_root); + HloInstruction* new_root = nullptr; + if (old_root->opcode() == HloOpcode::kTuple) { + new_root = computation->AddInstruction(old_root->Clone()); + } else { + std::vector operands; + if (!old_root->shape().IsTuple()) { + operands.push_back(old_root); + } else { + const Shape& old_shape = old_root->shape(); + for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) { + auto element = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + old_shape.tuple_shapes(i), old_root, i)); + operands.push_back(element); + } + } + new_root = + computation->AddInstruction(HloInstruction::CreateTuple(operands)); + } + VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; + computation->set_root_instruction(new_root); + VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; + } + hoisted_instructions[conditional] = b; + for (int64 i = 0; i < to_move_in_size; i++) { + Boundary b_to_move = to_move_in[i]; + HloInstruction* op = b_to_move.operands()[0]; + CHECK(op != nullptr); + bool to_be_used_outside = true; + VLOG(2) << "Mapping new boundary instr: " << op->ToString() << "\n"; + if (i < to_move_in_size - 1 && op->user_count() == 1 && + op->users()[0] == to_move_in[i + 1].operands()[0]) { + to_be_used_outside = false; + VLOG(2) << "Instruction is not to be used outside the branch\n"; + } + Boundary b(Boundary::Position::kInsideBranch); + for (int i = 0; i < branch_count; i++) { + auto computation = conditional->branch_computation(i); + TF_RETURN_IF_ERROR(CopyInOrOutOfConditional(b_to_move, i, computation, + hoisted_instructions)); + VLOG(2) << "After Copying to branch: " << computation->ToString() << "\n"; + if (to_be_used_outside) { + auto new_op = hoisted_instructions[op].operands()[i]; + auto new_root = computation->root_instruction(); + new_root->AppendOperand(new_op); + *new_root->mutable_shape()->add_tuple_shapes() = new_op->shape(); + VLOG(2) << "Extending conditional root " << i << " : " + << new_root->ToString() << "\n"; + } + VLOG(2) << "After extending branch root: " << computation->ToString() + << "\n"; + } + if (to_be_used_outside) { + // Modify uses of instructions outside of the conditionals + HloInstruction* gtr = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement(op->shape(), conditional, + op_index++)); + TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); + if (conditional_parent->root_instruction() == op) { + conditional_parent->set_root_instruction(gtr); + } + } + } + VLOG(2) << "Done copying instructions inside branch: " + << conditional->ToString(HloPrintOptions::Fingerprint()) << "\n"; + // Change conditional instruction shape to the shape of the new root. + HloInstruction* new_root = + conditional->branch_computation(0)->root_instruction(); + *conditional->mutable_shape() = new_root->shape(); + VLOG(2) << "Before removing instructions:" << conditional_parent->ToString() + << "\n"; + // Remove hoisted instructions from the branches. 
+  for (int64 i = to_move_in_size - 1; i >= 0; i--) {
+    Boundary boundary_to_move_in = to_move_in[i];
+    VLOG(2) << "Removing boundary:" << boundary_to_move_in.ToString() << "\n";
+    HloInstruction* op = boundary_to_move_in.operands()[0];
+    for (auto user : op->users()) {
+      VLOG(2) << "Has User: " << user->ToString() << "\n";
+    }
+    TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(op));
+  }
+  VLOG(2) << "Done moving instructions inside branches\n"
+          << conditional_parent->ToString(HloPrintOptions::Fingerprint())
+          << "\n";
+  return true;
+}
 
-StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
-  bool changed = false;
+// Group single chains of operands or uses of boundaries into new boundaries.
+class GroupConnectedBoundaries {
+ private:
+  std::vector<Boundary> connected_boundaries_, new_boundaries_;
+  HloInstruction* conditional_;
+  HloComputation* conditional_parent_;
+  bool is_layout_sensitive_;
+  absl::flat_hash_set<HloInstruction*> visited_;
 
-  if (pursue_full_conditional_code_motion_) {
-    std::vector<HloInstruction*> conditional_ops;
-    for (auto* comp : module->MakeComputationPostOrder()) {
-      for (auto* instr : comp->MakeInstructionPostOrder()) {
-        if (instr->opcode() == HloOpcode::kConditional) {
-          conditional_ops.push_back(instr);
+ public:
+  explicit GroupConnectedBoundaries(HloInstruction* conditional,
+                                    bool is_layout_sensitive)
+      : conditional_(conditional),
+        conditional_parent_(conditional->parent()),
+        is_layout_sensitive_(is_layout_sensitive) {}
+  // Returns true if `instruction` is worth hoisting out.
+  bool WorthHoisting(HloInstruction* instruction) {
+    // This is needed for the "moving-in" transformation, to prevent the root
+    // of the parent computation (which contains the conditional) from being
+    // moved inside the conditional.
+    if (instruction->opcode() == HloOpcode::kTuple &&
+        instruction == conditional_parent_->root_instruction()) {
+      return false;
+    }
+    switch (instruction->opcode()) {
+      case HloOpcode::kConvert:
+        // If Convert is after AllReduce, it is worth moving the AllReduce out
+        // of the conditional for AR/CRS combine. If Convert is after other
+        // ops such as Dot or Convolutional, it is better to keep the convert
+        // within the conditional so that it can be fused with Dot or
+        // Convolutional.
+        //
+        // TODO(b/154283721): figure out the scenario when convert can be
+        // fused with AllReduce out of conditional.
+        switch (instruction->operand(0)->opcode()) {
+          case HloOpcode::kAllReduce:
+          case HloOpcode::kReshape:
+            return true;
+          default:
+            VLOG(1) << "Instruction is convert and its operand is not known "
+                       "to be worth hoisting\n";
+            return false;
+        }
+      case HloOpcode::kAllReduce:
+      case HloOpcode::kAdd:
+      case HloOpcode::kPower:
+      case HloOpcode::kConstant:
+      case HloOpcode::kSubtract:
+      case HloOpcode::kMultiply:
+      case HloOpcode::kDivide:
+      case HloOpcode::kTuple:
+      case HloOpcode::kSqrt:
+      case HloOpcode::kReshape:
+      case HloOpcode::kGetTupleElement:
+        return true;
+      default:
+        VLOG(1) << "Instruction is not known to be worth hoisting\n";
+        return false;
+    }
+  }
+  int64 ReusesBeforeBoundary(HloInstruction* user) {
+    int64 reuses = 0;
+    for (auto op : user->operands()) {
+      // Only consider single-user cases as reusable.
+      if (ContainsKey(visited_, op) && op->user_count() == 1) {
+        reuses += ReusesCarriedBy(op, user);
+      } else if (op->opcode() == HloOpcode::kConditional &&
+                 user->opcode() == HloOpcode::kGetTupleElement) {
+        if (user->user_count() == 1) {
+          reuses += ReusesCarriedBy(op, user->users()[0]);
        }
      }
    }
+    VLOG(1) << "Reuses before instruction " << user->ToString() << ":" << reuses
+            << "\n";
+    return reuses;
+  }
 
-    for (HloInstruction* conditional_op : conditional_ops) {
-      TF_ASSIGN_OR_RETURN(
-          bool result,
-          MergeIdenticalElements(conditional_op, is_layout_sensitive_));
-      changed |= result;
+  int64 ReusesAfterBoundary(HloInstruction* user) {
+    CHECK(user != nullptr);
+    auto all_users = user->users();
+    // For now, assume that if an instruction has multiple consumers, it
+    // will not be reused, as the reuse may require duplication in
+    // fusion and so is expensive. If the situation changes in the future,
+    // some aspects of the overall algorithm need to be redesigned to
+    // accommodate the change.
+    if (all_users.size() > 1) {
+      return 0;
    }
+    if (!all_users.empty()) {
+      auto op = all_users[0];
+      int64 reuses = 0;
+      // Only count reuses that run through the conditional root.
+      if (op == conditional_->branch_computation(0)->root_instruction()) {
+        int64 index = op->operand_index(user);
+        for (auto op2 : conditional_->users()) {
+          // If the use is not a get-tuple-element, do not consider it for now.
+          if (op2->opcode() == HloOpcode::kGetTupleElement) {
+            auto tuple_opd = static_cast<HloGetTupleElementInstruction*>(op2);
+            if (index == tuple_opd->tuple_index()) {
+              all_users = op2->users();
+              if (!all_users.empty()) {
+                reuses += ReusesCarriedBy(user, all_users[0]);
+                break;
+              }
+            }
+          }
+        }
+      } else if (ContainsKey(visited_, op)) {
+        reuses += ReusesCarriedBy(user, op);
+      }
+      VLOG(1) << "reuses after instruction " << user->ToString() << ":"
+              << reuses << "\n";
+      return reuses;
+    }
+    return 0;
+  }
 
-    if (changed) {
-      HloPassPipeline subpipeline("after_conditional_code_motion");
-      subpipeline.AddPass();
-      subpipeline.AddPass();
-      subpipeline.AddPass();
-      TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module));
-      changed |= cleanup_changed;
+  int64 BenefitForMovingBoundaries(const std::vector<Boundary>& boundaries) {
+    int64 reuses_before = 0, reuses_after = 0;
+    if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch()) {
+      // The only boundary of moving-in is the get_tuple_element op.
+      return -1;
+    }
+    for (Boundary b : boundaries) {
+      auto op = b.operands()[0];
+      if (op == conditional_->branch_computation(0)->root_instruction()) {
+        continue;
+      }
+      reuses_before += ReusesBeforeBoundary(op);
+      VLOG(1) << "Reuses before boundary so far: " << reuses_before << "\n";
+      reuses_after += ReusesAfterBoundary(op);
+      VLOG(1) << "Reuses after boundary so far: " << reuses_after << "\n";
+    }
+    if (reuses_after == 0 && reuses_before == 0) {
+      return -1;
+    } else if (boundaries[0].IsInsideBranch()) {
+      return reuses_after - reuses_before;
+    } else {
+      return reuses_before - reuses_after;
    }
  }
+  Boundary GetNextBoundary(const Boundary& b, int64 op_index) {
+    Boundary b2(b.GetPosition());
+    for (int j = 0; j < b.operands().size(); ++j) {
+      HloInstruction* inst = b.operands()[j];
+      CHECK(inst != nullptr);
+      HloInstruction* op = (b.IsInsideBranch()) ? inst->operands()[op_index]
+                                                : inst->users()[op_index];
+      CHECK(op != nullptr);
+      b2.mutable_operands().push_back(op);
+    }
+    return b2;
+  }
+  int64 CountNonLeafOps(const xla::HloInstruction::InstructionVector& ops) {
+    int64 count = 0;
+    absl::flat_hash_set<HloInstruction*> op_set;
+    for (auto op : ops) {
+      if (!op_set.contains(op) && op->opcode() != HloOpcode::kConstant) {
+        count++;
+        op_set.insert(op);
+      }
+    }
+    return count;
+  }
+  // This function is reused both for moving a boundary out of and into a
+  // conditional. As a result, readability is somewhat compromised.
+  // It might be nice to refactor this function to factor the outside-inside
+  // considerations into separate function pointer parameters to improve
+  // readability.
+  void AddBoundaries(const Boundary& boundary) {
+    BoundaryVisitor visitor;
+    visitor.AddToWorkList(boundary);
+    while (visitor.HasNextBoundary()) {
+      Boundary b = visitor.PopNextBoundary();
+      VLOG(1) << "visiting boundary " << b.ToString() << "\n";
+      if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical(
+                                      b.operands(), is_layout_sensitive_)) &&
+          WorthHoisting(b.operands()[0])) {
+        connected_boundaries_.push_back(b);
+        VLOG(1) << "boundary can be moved\n";
+        int64 operand_count = (b.IsInsideBranch())
+                                  ? b.operands()[0]->operand_count()
+                                  : b.operands()[0]->users().size();
+        for (int i = 0; i < operand_count; i++) {
+          Boundary next_boundary = GetNextBoundary(b, i);
+          int64 next_boundary_count =
+              (next_boundary.IsInsideBranch())
+                  ? next_boundary.operands()[0]->user_count()
+                  : CountNonLeafOps(next_boundary.operands()[0]->operands());
+          // Only consider adding an exclusive producer into the same group.
+          if (next_boundary_count == 1) {
+            VLOG(2) << "Add operand " << i << " to visit later\n";
+            visitor.AddToWorkList(next_boundary);
+          } else {
+            VLOG(2) << "Next boundary " << i
+                    << " has multiple uses: " << next_boundary_count << "\n";
+            if (!ContainsKey(visited_, next_boundary.operands()[0])) {
+              visited_.insert(next_boundary.operands()[0]);
+              new_boundaries_.push_back(next_boundary);
+            }
+          }
+        }
+      } else {
+        VLOG(1) << "boundary cannot be moved\n";
+        visited_.insert(b.operands()[0]);
+        new_boundaries_.push_back(b);
+      }
+    }
+  }
+  std::vector<Boundary> BoundariesToMoveInOrOut(const Boundary& b) {
+    // At the beginning of optimization, a conditional itself is added to a
+    // worklist. Here the conditional is expanded into two sets of boundaries:
+    // the first set is a single boundary inside the branches, holding the
+    // roots of all branches; the second set contains one boundary for each
+    // user of the conditional.
+    HloInstruction* inst = b.operands()[0];
+    if (inst->opcode() == HloOpcode::kConditional) {
+      int branch_count = inst->branch_count();
+      // Add conditional roots as a new boundary to visit.
+      Boundary boundary_in(Boundary::Position::kInsideBranch);
+      for (int i = 0; i < branch_count; i++) {
+        HloComputation* branch_computation = inst->branch_computation(i);
+        HloInstruction* root_inst = branch_computation->root_instruction();
+        CHECK(root_inst != nullptr);
+        boundary_in.mutable_operands().push_back(root_inst);
+      }
+      new_boundaries_.push_back(boundary_in);
+      // Add conditional users as new boundaries to visit.
+      for (auto u : inst->users()) {
+        Boundary boundary_in(Boundary::Position::kOutsideBranch);
+        boundary_in.mutable_operands().push_back(u);
+        new_boundaries_.push_back(boundary_in);
+      }
+    } else {
+      AddBoundaries(b);
+    }
+    return connected_boundaries_;
+  }
+  void AddNewBoundaries(std::vector<Boundary>& b) {
+    b.insert(b.end(), new_boundaries_.begin(), new_boundaries_.end());
+  }
+};
+
+ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion(
+    HloInstruction* conditional, const Boundary& cur_boundary,
+    std::vector<Boundary>& to_move, std::vector<Boundary>& new_boundaries) {
+  GroupConnectedBoundaries connect(conditional, is_layout_sensitive_);
+  auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary);
+  if (!move_in_or_out.empty()) {
+    auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out);
+    VLOG(1) << "benefit of moving in or out "
+            << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n";
+    if (benefit >= 0) {
+      new_boundaries.clear();
+      connect.AddNewBoundaries(new_boundaries);
+      // The whole sequence in move_in_or_out is either all moving into a
+      // conditional, or all moving out of a conditional. So looking only
+      // at the first entry of the sequence is sufficient to know in which
+      // direction the move is intended.
+      to_move = move_in_or_out;
+      return to_move[0].IsInsideBranch() ? Decision::kMoveOutOfBranch
+                                         : Decision::kMoveIntoBranch;
+    }
+  } else {
+    connect.AddNewBoundaries(new_boundaries);
+  }
+  return ConditionalCodeMotion::Decision::kNoChange;
+}
+
+StatusOr<bool> ConditionalCodeMotion::Run(HloModule* module) {
+  // Gather all the conditional ops in the module ahead of time, to avoid
+  // potential complications from modifying the code that is being traversed.
+  std::vector<HloInstruction*> conditional_ops;
+  for (auto* comp : module->MakeComputationPostOrder()) {
+    for (auto* instr : comp->MakeInstructionPostOrder()) {
+      if (instr->opcode() == HloOpcode::kConditional) {
+        conditional_ops.push_back(instr);
+      }
+    }
+  }
+
+  bool changed = false;
+  for (HloInstruction* conditional : conditional_ops) {
+    // Boundaries to move out of or to move into the branches.
+    std::vector<Boundary> to_move_out, to_move_in, new_boundaries;
+    // The conditional is moved into a worklist as the seed (starting point).
+    // The conditional will be expanded into multiple seeds (starting points),
+    // its roots and its users, when it is visited by GroupConnectedBoundaries.
+    // A kNoChange decision will always be returned for the conditional itself,
+    // so that the other seeding boundaries can be visited in turn.
+    BoundaryVisitor visitor(conditional);
+    VLOG(2) << "Analyzing conditional:" << conditional->ToString() << "\n";
+    ConditionalCodeMotion::Decision d = Decision::kNoChange;
+    // The following loop breaks out as soon as a decision to modify the
+    // conditional is reached --- irrespective of whether visitor is empty.
+ while (d == Decision::kNoChange && visitor.HasNextBoundary()) { + std::vector to_move, next_boundary; + Boundary boundary = visitor.PopNextBoundary(); + VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; + d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); + switch (d) { + case Decision::kMoveOutOfBranch: + VLOG(2) << "Decision is move out of branch\n"; + to_move_out.insert(to_move_out.end(), to_move.begin(), to_move.end()); + new_boundaries.insert(new_boundaries.end(), next_boundary.begin(), + next_boundary.end()); + break; + case Decision::kMoveIntoBranch: + VLOG(2) << "Decision is move into branch\n"; + to_move_in.insert(to_move_in.end(), to_move.begin(), to_move.end()); + new_boundaries.insert(new_boundaries.end(), next_boundary.begin(), + next_boundary.end()); + break; + case Decision::kNoChange: + VLOG(2) << "Decision is no change\n"; + for (const Boundary& b : next_boundary) { + visitor.AddToWorkList(b); + } + break; + } + } + // At most one of to_move_out or to_move_in can be non-empty, since there is + // only one optimization decision. + if (!to_move_out.empty()) { + TF_ASSIGN_OR_RETURN( + bool result, + MoveInstructionOut(conditional, to_move_out, new_boundaries)); + VLOG(2) << "moving out result:" << result << "\n"; + changed |= result; + } else if (!to_move_in.empty()) { + TF_ASSIGN_OR_RETURN( + bool result, + MoveInstructionIn(conditional, to_move_in, new_boundaries)); + VLOG(2) << "moving in result:" << result << "\n"; + changed |= result; + } + } // handling convert rematerialization/hoisting - { + if (!changed && pursue_full_conditional_code_motion_) { std::vector conditional_ops; for (auto* comp : module->MakeComputationPostOrder()) { for (auto* instr : comp->MakeInstructionPostOrder()) { @@ -711,7 +980,6 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { changed |= convert_result; } } - if (changed) { HloPassPipeline subpipeline( "after_conditional_code_motion_after_convert_hoisting"); @@ -721,8 +989,8 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); changed |= cleanup_changed; } - return changed; } +} // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h index 95f02833e15..68a2aa58235 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.h +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -23,35 +23,81 @@ limitations under the License. namespace xla { -// ConditionalCodeMotion specializes in hoisting/rematerializing -// unconditional converts in the default mode. -// When pursue_full_conditional_code_motion_ is set to true, the -// full HLO pass moves identical ops out of a conditional in addition to moving -// converts. +namespace conditional_opt { +// At the conceptual level, a boundary can be thought of as representing a +// single virtual operation, except this virtual operation is conditionally +// instantiated into different concrete operations at each conditional branch. +// So a boundary is mapped to a single concrete operation if it is outside of +// conditional branches, and is mapped to a list of instructions if inside the +// branches. This data structure therefore allows a common data structure +// representation of the instructions to be moved, whether they are inside or +// outside of the branches. 
Subsequently, it allows a common implementation +// basis to be used for both moving instructions out of and for moving them +// inside branches. +class Boundary { + public: + enum class Position { kInsideBranch, kOutsideBranch, kUndefined }; + Boundary() : position_(Position::kUndefined) {} + explicit Boundary(Position p) : position_(p) {} + std::vector& mutable_operands() { return operands_; } + const std::vector& operands() const { return operands_; } + bool IsInsideBranch() const { return position_ == Position::kInsideBranch; } + bool IsOutsideBranch() const { return position_ == Position::kOutsideBranch; } + Position GetPosition() const { return position_; } + bool IsEmpty() const { return operands_.empty(); } + std::string ToString() const { + std::string res; + for (HloInstruction* op : operands_) { + res += op->ToString() + ";"; + } + return res; + } + + private: + // Boundary instructions in the conditional branches, one from each branch + // of the conditional; or a single operand from outside the conditional. + std::vector operands_; + Position position_; +}; + +// HLO pass that moves identical ops in/out of conditional. // - The definition of identical are the shape of the operands are identical // and their properties are identical. -// - Currently, only some types of instructions is supported. -// TODO(b/154283721): relax non-sharable operand constraint and avoid copies in -// the new root. // - Only the identical ops that won't share operands with other ops will // be moved out of conditional. class ConditionalCodeMotion : public HloModulePass { public: // If is_layout_sensitive is true, then the hoist process preserves layout // during identical comparison. Otherwise, layout is ignored. - explicit ConditionalCodeMotion( - bool is_layout_sensitive = true, - bool pursue_full_conditional_code_motion = false) + explicit ConditionalCodeMotion(bool is_layout_sensitive, + bool pursue_full_conditional_code_motion) : is_layout_sensitive_(is_layout_sensitive), pursue_full_conditional_code_motion_( pursue_full_conditional_code_motion) {} absl::string_view name() const override { return "conditional-code-motion"; } StatusOr Run(HloModule* module) override; + // Optimization decision for each boundary of the conditional instruction. + enum class Decision { kMoveOutOfBranch, kMoveIntoBranch, kNoChange }; + // If the optimization decision is NO_CHANGE, new_boundary is set to nullptr; + // otherwise, it is set to the new boundary after proposed optimization. + virtual Decision ConsiderCodeMotion(HloInstruction* conditional, + const Boundary& cur_boundary, + std::vector& to_move, + std::vector& new_boundaries); + private: const bool is_layout_sensitive_; const bool pursue_full_conditional_code_motion_; + + StatusOr MoveInstructionOut(HloInstruction* conditional, + std::vector& to_move_out, + std::vector& new_boundaries); + StatusOr MoveInstructionIn(HloInstruction* conditional, + std::vector& to_move_in, + std::vector& new_boundaries); }; +} // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index 38b2b515fa0..b0a6ba92f48 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -33,7 +33,7 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" namespace xla { -namespace { +namespace conditional_opt { using ConditionalCodeMotionTest = HloTestBase; namespace op = xla::testing::opcode_matchers; @@ -117,6 +117,47 @@ ENTRY main { EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); } +TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditional) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) + %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); +} + TEST_F(ConditionalCodeMotionTest, MoveConvertOut) { absl::string_view hlo_string = R"( @@ -152,8 +193,20 @@ ENTRY main { ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 2); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 2); + HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert())))); + EXPECT_THAT( + root, + AllOf(op::Tuple(op::Add(op::Convert(op::Reshape(op::GetTupleElement( + op::GetTupleElement(op::Conditional())))), + op::Convert(op::Reshape(op::GetTupleElement( + op::GetTupleElement(op::Conditional())))))))); } TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { @@ -173,7 +226,7 @@ on_true { add.2 = f32[] add(add.1, constant.2) add.3 = f32[] add(add.1, constant.3) add.4 = f32[] add(add.3, constant.5) - multiply.1 = f32[] multiply(add.2, constant.4) + multiply.1 = f32[] multiply(add.4, 
constant.4) ROOT tuple.6 = (f32[], f32[]) tuple(multiply.1, add.4) } @@ -202,7 +255,7 @@ ENTRY main { false_computation=on_false get-first-index = f32[] get-tuple-element(conditional), index=0 get-second-index = f32[] get-tuple-element(conditional), index=1 - ROOT result = (f32[], f32[]) tuple(get-first-index, get-second-index) + ROOT result = f32[] add(get-first-index, get-second-index) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); @@ -216,13 +269,11 @@ ENTRY main { const HloComputation* on_false = conditional->branch_computation(1); ASSERT_EQ(on_false->instruction_count(), 9); - // Check only one add and multiply is moved out. HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::Tuple( - op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), - op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); + EXPECT_THAT(root, + AllOf(op::Add(op::Multiply(op::GetTupleElement(op::Conditional()), + op::Constant()), + op::GetTupleElement(op::Conditional())))); } TEST_F(ConditionalCodeMotionTest, ConditionalRootElementChanged) { @@ -260,7 +311,7 @@ ENTRY main { conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, false_computation=on_false get-first-index = f32[] get-tuple-element(conditional), index=0 - ROOT result = (f32[]) tuple(get-first-index) + ROOT result = f32[] add(get-first-index, get-first-index) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); @@ -269,16 +320,21 @@ ENTRY main { const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); const HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 7); + ASSERT_EQ(on_true->instruction_count(), 1); const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 7); + ASSERT_EQ(on_false->instruction_count(), 1); - // add.3 in on_true will be moved out, add.1 and add.2 will be in condtional - // root. - ASSERT_TRUE(ShapeUtil::Compatible( - conditional->shape(), - ShapeUtil::MakeTupleShape( - {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}))); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Add( + op::Add( + op::Add(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), op::Constant())), + op::Add( + op::Add(op::GetTupleElement(op::Conditional()), op::Constant()), + op::Add(op::GetTupleElement(op::Conditional()), + op::Constant()))))); } TEST_F(ConditionalCodeMotionTest, ConditionalIsRootInstruction) { @@ -329,24 +385,9 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); ConditionalCodeMotion pass(true, true); - ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); - - const HloInstruction* conditional = - FindInstruction(module.get(), "conditional"); - const HloComputation* on_true = conditional->branch_computation(0); - ASSERT_EQ(on_true->instruction_count(), 9); - const HloComputation* on_false = conditional->branch_computation(1); - ASSERT_EQ(on_false->instruction_count(), 9); - - // Check only one add and multiply is moved out. - // add.3 and add.5 can't be moved out because they share operands with - // other instructions. 
- HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT( - root, - AllOf(op::Tuple( - op::Multiply(op::GetTupleElement(op::Conditional()), op::Constant()), - op::Add(op::GetTupleElement(op::Conditional()), op::Constant())))); + // If there is no instruction after the conditional, there is no benefit to + // move + ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); } TEST_F(ConditionalCodeMotionTest, LayoutMisMatchCannotMovedOut) { @@ -469,7 +510,8 @@ ENTRY main { false_computation=on_false get-first-index = f32[3,3,128,128] get-tuple-element(conditional), index=0 - ROOT result = (f32[3,3,128,128]) tuple(get-first-index) + add.1 = f32[3,3,128,128] add(f32[3,3,128,128] get-first-index, f32[3,3,128,128] get-first-index) + ROOT result = (f32[3,3,128,128]) tuple(add.1) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); @@ -487,10 +529,57 @@ ENTRY main { conditional->shape(), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape( BF16, {3, 3, 128, 128})}))); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(op::AllReduce( - op::GetTupleElement(op::Conditional())))))); + EXPECT_THAT( + root, + AllOf(op::Tuple(op::Add( + op::Convert(op::AllReduce(op::GetTupleElement(op::Conditional()))), + op::Convert( + op::AllReduce(op::GetTupleElement(op::Conditional()))))))); } -} // namespace +TEST_F(ConditionalCodeMotionTest, MovePowOpIn) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +on_true { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +on_false { + arg_tuple.2 = (f32[10]) parameter(0) + get-tuple-element.2 = f32[10] get-tuple-element(arg_tuple.2), index=0 + mul.1 = f32[10] multiply(get-tuple-element.2, get-tuple-element.2) + ROOT tuple.4 = (f32[10]) tuple(mul.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=on_true, + false_computation=on_false + get-first-index = f32[10] get-tuple-element(conditional), index=0 + ROOT pow.1 = f32[10] power(get-first-index, get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} +} // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index bb19a63a9ce..199bc787b83 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -41,6 +41,26 @@ limitations under the License. namespace xla { namespace { + +// A computation with array type that only contains parameters and tuples is +// considered emtpy. 
+bool ComputationIsEmptyWithArrayRoot(const HloComputation* computation) { + bool empty_operations = absl::c_all_of( + computation->MakeInstructionPostOrder(), [](const HloInstruction* inst) { + return inst->opcode() == HloOpcode::kTuple || + inst->opcode() == HloOpcode::kGetTupleElement || + inst->opcode() == HloOpcode::kParameter; + }); + bool contains_array = false; + ShapeUtil::ForEachSubshape(computation->root_instruction()->shape(), + [&](const Shape& shape, const ShapeIndex& index) { + if (shape.IsArray()) { + contains_array = true; + } + }); + return empty_operations && contains_array; +} + // Tries to replace a conditional with a call operation of the corresponding // computation. If the given conditional has a constant branch_index, tries to // replace it with a call to its corresponding branch computation and then @@ -124,7 +144,6 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { << conditional->ToShortString(); return false; } - HloInstruction* true_call_op = create_call(0); HloInstruction* false_call_op = create_call(1); auto condition_broadcast = [&](const Shape& shape) { @@ -140,6 +159,14 @@ StatusOr TryRemoveConditional(HloInstruction* conditional) { return computation->AddInstruction(HloInstruction::CreateGetTupleElement( hlo->shape().tuple_shapes(i), hlo, i)); }; + + bool branch_empty = + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(0)) || + ComputationIsEmptyWithArrayRoot(conditional->branch_computation(1)); + // Empty branch is faster to execute than select. + if (branch_empty) { + return false; + } std::function select = [&](HloInstruction* t, HloInstruction* f) { if (f->shape().IsToken()) { @@ -559,6 +586,10 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { absl::flat_hash_set removed_conditionals; for (HloInstruction* conditional_op : conditional_ops) { + if (conditional_op->has_sharding()) { + // The code below doesn't handle sharding properly. + continue; + } changed |= MergeDuplicateTupleElements(conditional_op); changed |= RemoveUnusedTupleElements(conditional_op); changed |= ReplaceRootWithEmptyTupleIfNoUsers(conditional_op); @@ -573,18 +604,27 @@ StatusOr ConditionalSimplifier::Run(HloModule* module) { // lets collect them first. absl::flat_hash_map> calling_conditionals; + // Keys of calling_conditionals to get a deterministic ordering. 
+  std::vector<HloComputation*> calling_computationals_vector;
   for (HloInstruction* conditional : conditional_ops) {
     if (removed_conditionals.contains(conditional)) {
       continue;
     }
+
     for (int64 branch = 0; branch < conditional->branch_count(); ++branch) {
-      calling_conditionals[conditional->branch_computation(branch)].insert(
-          conditional);
+      auto* branch_comp = conditional->branch_computation(branch);
+      if (!calling_conditionals.contains(branch_comp)) {
+        calling_computationals_vector.push_back(branch_comp);
+      }
+      calling_conditionals[branch_comp].insert(conditional);
     }
   }
-  for (const auto& entry : calling_conditionals) {
+
+  for (auto* comp : calling_computationals_vector) {
+    auto entry = calling_conditionals.find(comp);
+    CHECK(entry != calling_conditionals.end());
     TF_ASSIGN_OR_RETURN(bool result, TryRemoveUnusedConditionalOperands(
-                                         entry.first, entry.second));
+                                         entry->first, entry->second));
     changed |= result;
   }
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 6bfd8c4db46..b88120d8128 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -191,6 +191,30 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
   return any_copies;
 }
 
+// Compute the indices of the conditional outputs which need copies.
+// Unambiguous buffers (buffers with only one value) don't need copies.
+bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
+                                 const HloInstruction* xla_conditional,
+                                 ShapeTree<bool>* indices_to_copy) {
+  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
+                               xla_conditional->shape()));
+
+  bool any_copies = false;
+  for (auto& pair : *indices_to_copy) {
+    const ShapeIndex& index = pair.first;
+    bool& should_copy = pair.second;
+
+    CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
+
+    auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
+    // The conditional must be copied if the value is a phi.
+    should_copy =
+        value->is_phi() && value->defining_instruction() == xla_conditional;
+    any_copies |= should_copy;
+  }
+  return any_copies;
+}
+
 // Add kCopy instructions around the given kWhile instruction to eliminate any
 // possible live range interference of HLO values assuming a dependency-based
 // ordering (HloDependencyOrdering). Copies are added conservatively. There
@@ -306,24 +330,30 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
   }
   body->set_root_instruction(root_copy);
-
   return Status::OK();
 }
 
-// We add copies for all the indices of the true and false computation roots, in
-// order to resolve interference. We later rely on RemoveUnnecessaryCopies to
-// drop the unnecessary ones.
+// We add copies for all non-phi indices of the true and false computation
+// roots, in order to resolve interference. We later rely on
+// RemoveUnnecessaryCopies to drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) { VLOG(2) << "Adding copies for kConditional instruction " << conditional->name(); + ShapeTree indices_to_copy(conditional->shape()); TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional); - + if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(), + conditional, &indices_to_copy)) { + VLOG(2) << "No copies necessary for kWhile instruction " + << conditional->name(); + return Status::OK(); + } for (HloComputation* computation : conditional->branch_computations()) { HloInstruction* root = computation->root_instruction(); std::vector users = root->users(); - TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, - computation->DeepCopyInstruction(root)); + TF_ASSIGN_OR_RETURN( + HloInstruction * deep_copy, + computation->DeepCopyInstruction(root, &indices_to_copy)); for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy)); } @@ -1128,6 +1158,7 @@ static int64 GetNumExistingCopies(const HloModule* module) { Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module) { + XLA_VLOG_LINES(4, module->ToString()); TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer_)); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 7f051d4d1b2..7c362b2da44 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -49,6 +49,7 @@ filegroup( "runtime_single_threaded_conv2d.cc", "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", + "runtime_topk.cc", ], visibility = [":friends"], ) @@ -64,6 +65,7 @@ filegroup( "runtime_single_threaded_conv2d.h", "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", + "runtime_topk.h", ], visibility = [":friends"], ) @@ -134,13 +136,16 @@ cc_library( "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:dump", + "//tensorflow/compiler/xla/service:topk_rewriter", "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:tree_reduction_rewriter", "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", + "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_to_select", "//tensorflow/compiler/xla/service:slow_operation_alarm", "//tensorflow/compiler/xla/service:scatter_expander", + "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:slice_sinker", "//tensorflow/compiler/xla:cpu_function_runtime", "//tensorflow/compiler/xla:literal", @@ -179,6 +184,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", "//tensorflow/compiler/xla/service:llvm_compiler", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:rng_expander", "//tensorflow/compiler/xla/service:sort_simplifier", @@ -229,6 +235,7 @@ cc_library( ":runtime_fft", ":runtime_fork_join", ":runtime_key_value_sort", + ":runtime_topk", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -615,7 +622,8 @@ cc_library( deps = [ ":runtime_lightweight_check", "//tensorflow/compiler/xla:executable_run_options", - 
"//tensorflow/core/kernels:eigen_helpers_no_mkl", + "//tensorflow/core/kernels:eigen_contraction_kernel", + "//tensorflow/core/kernels:eigen_helpers", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:types", @@ -703,6 +711,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":runtime_lightweight_check", + "//tensorflow/core/kernels:eigen_contraction_kernel", "//tensorflow/core/kernels:eigen_helpers", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:types", @@ -756,6 +765,19 @@ cc_library( ], ) +cc_library( + name = "runtime_topk", + srcs = ["runtime_topk.cc"], + hdrs = ["runtime_topk.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core/platform:dynamic_annotations", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:types", + ], +) + cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5464cfee082..39d2b11ad37 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -54,6 +54,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/conditional_to_select.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" @@ -76,6 +78,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" @@ -103,6 +106,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/slice_sinker.h" #include "tensorflow/compiler/xla/service/slow_operation_alarm.h" #include "tensorflow/compiler/xla/service/sort_simplifier.h" +#include "tensorflow/compiler/xla/service/topk_rewriter.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" @@ -258,6 +262,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -284,6 +289,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*rewrite_grad_op=*/true); pipeline.AddPass( /*expansion_type=*/LogisticExpansionType::kExp); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -300,6 +306,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(options); pass.AddPass(); pass.AddPass(); + pass.AddPass(GatherExpander::kEliminateSimpleGathers); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. 
@@ -318,6 +325,9 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pass.AddPass(); pass.AddPass(); } + pipeline.AddPass([](const HloSortInstruction* sort, int64) { + return sort->operand(0)->shape().element_type() == F32; + }); pipeline.AddPass(); pipeline.AddPass( [&](const HloInstruction& dot, @@ -614,10 +624,9 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - auto llvm_module = absl::make_unique( - "__compute_module", - mlir_context.getRegisteredDialect() - ->getLLVMContext()); + llvm::LLVMContext llvm_context; + auto llvm_module = + absl::make_unique("__compute_module", llvm_context); auto jit = absl::make_unique( CompilerTargetOptions(module->config()), @@ -826,10 +835,8 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; - llvm::Module llvm_module( - "__compute_module", - mlir_context.getRegisteredDialect() - ->getLLVMContext()); + llvm::LLVMContext llvm_context; + llvm::Module llvm_module("__compute_module", llvm_context); llvm_module.setDataLayout(target_machine->createDataLayout()); llvm_module.setTargetTriple(triple.getTriple()); if (pic_level != llvm::PICLevel::NotPIC) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 0abcc91a1d7..7431e829b8e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -247,6 +247,12 @@ StatusOr CpuExecutable::CreateResultShapedBuffer( ExecutionInput& input = arguments[alias->parameter_number]; MaybeOwningDeviceMemory* maybe_owning_memory = input.MutableBuffer(alias->parameter_index); + if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) { + return InvalidArgument( + "An input was configured to be must-alias at " + "compile time but not donated at runtime: %s", + alias->ToString()); + } if (absl::optional owning = maybe_owning_memory->Release()) { // If the caller passes the ownership of the device memory, reuse it diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index c0222010fd9..ff654c83d61 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,7 +25,6 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; -const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -64,12 +63,6 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } -bool UseLinalgForDot(const HloModuleConfig& config) { - const auto& extra_options_map = - config.debug_options().xla_backend_extra_options(); - return extra_options_map.count(kXlaUseLinalgForDot) > 0; -} - static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 5d25aef6912..99e6702d14a 100644 --- 
a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,7 +27,6 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config); -bool UseLinalgForDot(const HloModuleConfig& config); absl::optional LlvmIrGemvTilingFactor(const HloModuleConfig& config); absl::optional> LlvmIrGemmTileSize( const HloModuleConfig& config); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 2231ecfa1e8..5bee6049a5e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -117,6 +117,7 @@ extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; extern const char* const kKeyValueSortSymbolName = "__xla_cpu_runtime_KeyValueSort"; +extern const char* const kTopKF32SymbolName = "__xla_cpu_runtime_TopKF32"; extern const char* const kTracingStartSymbolName = "__xla_cpu_runtime_TracingStart"; extern const char* const kTracingEndSymbolName = "__xla_cpu_runtime_TracingEnd"; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index ee75b97e4dc..eb24e0bc334 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -72,6 +72,7 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; extern const char* const kKeyValueSortSymbolName; +extern const char* const kTopKF32SymbolName; extern const char* const kAllReduceSymbolName; extern const char* const kCollectivePermuteSymbolName; extern const char* const kReplicaIdSymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index ee4bcf4cd35..2b3865b4dba 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -270,11 +270,48 @@ Status DotOpEmitter::EmitLinalgMatmul() { return EmitMlirFuncAndCall( mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { + CHECK_EQ(dot_info_.dim_nums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dot_info_.dim_nums.rhs_contracting_dimensions_size(), 1); + mlir::MLIRContext* context = builder->getContext(); mlir::edsc::ScopedContext scope(*builder, function.getLoc()); mlir::Value a = function.getArgument(0), b = function.getArgument(1), c = function.getArgument(2); - mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{}, - mlir::ValueRange{b, c, a}); + + llvm::SmallVector b_exprs( + dot_info_.lhs_shape.rank()); + llvm::SmallVector c_exprs( + dot_info_.rhs_shape.rank()); + + llvm::SmallVector parallel_exprs; + mlir::AffineExpr reduce_expr; + for (int i = 0; i != dot_info_.result_shape.rank(); ++i) { + parallel_exprs.push_back(mlir::getAffineDimExpr(i, context)); + } + reduce_expr = + mlir::getAffineDimExpr(dot_info_.result_shape.rank(), context); + + // The reduction expr is shared for both inputs. 
+ b_exprs[dot_info_.dim_nums.lhs_contracting_dimensions(0)] = reduce_expr; + c_exprs[dot_info_.dim_nums.rhs_contracting_dimensions(0)] = reduce_expr; + + // Fill in the remaining parallel exprs. + int par_expr_num = 0; + for (auto* v : {&b_exprs, &c_exprs}) { + for (auto& e : *v) { + if (!e) { + e = parallel_exprs[par_expr_num++]; + } + } + } + + llvm::SmallVector types( + parallel_exprs.size(), mlir::IteratorType::Parallel); + types.push_back(mlir::IteratorType::Reduction); + + mlir::edsc::StructuredIndexed s_a(a), s_b(b), s_c(c); + mlir::edsc::makeGenericLinalgOp(types, {s_b(b_exprs), s_c(c_exprs)}, + {s_a(parallel_exprs)}, + mlir::edsc::ops::macRegionBuilder); mlir::edsc::intrinsics::std_ret(); mlir::linalg::LinalgTilingOptions tilingOptions; @@ -283,13 +320,13 @@ Status DotOpEmitter::EmitLinalgMatmul() { target_machine_features_.minimum_alignment_for_allocation( ShapeUtil::ByteSizeOf(dot_info_.result_shape)); mlir_strategy::MatmulCodegenStrategy strategy; - strategy.tile(tilingOptions) - .promote( + strategy.tile(tilingOptions) + .promote( mlir::linalg::LinalgPromotionOptions() .setAlignment(alignment) .setUseFullTileBuffersByDefault(true) .setUseAlloca(true)) - .vectorize() + .vectorize() .setVectorTransformsOptions( mlir::vector::VectorTransformsOptions() .setVectorTransformsOptions( @@ -986,9 +1023,7 @@ DotImplementationStrategy GetDotImplementationStrategy( if (IsAlignedGemm(dot_info, target_machine_features)) { if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { - return options::UseLinalgForDot(config) - ? DotImplementationStrategy::kLinalgMatmul - : DotImplementationStrategy::kTiledLlvmIrGemm; + return DotImplementationStrategy::kLinalgMatmul; } return DotImplementationStrategy::kEigen; } diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index ebb2df23805..242f3c6ceb7 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1397,6 +1397,7 @@ Status IrEmitter::HandleAllToAll(HloInstruction* instruction) { Status IrEmitter::HandleCollectivePermute(HloInstruction* crs) { auto* instr = Cast(crs); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(instr)); std::string source_target_pairs = absl::StrJoin( instr->source_target_pairs(), ",", absl::PairFormatter("=")); llvm::Value* source_target_pairs_v = @@ -2386,6 +2387,45 @@ Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { return Status::OK(); } +Status IrEmitter::HandleTopK(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + const HloInstruction* input = hlo->operand(0); + const int64 k = hlo->shape().tuple_shapes(0).dimensions().back(); + const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2; + TF_RET_CHECK(input->shape().element_type() == F32); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hlo->shape().tuple_shapes(0).layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major( + hlo->shape().tuple_shapes(1).layout())); + TF_RET_CHECK( + LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout())); + + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice, + assignment_.GetUniqueSlice(hlo->operand(0), {})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_values_slice, + assignment_.GetUniqueSlice(hlo, {0})); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice out_indices_slice, + assignment_.GetUniqueSlice(hlo, {1})); + llvm::Value* values_ptr = + EmitBufferPointer(values_slice, 
hlo->operand(0)->shape()); + llvm::Value* out_values_ptr = + EmitBufferPointer(out_values_slice, hlo->shape().tuple_shapes(0)); + llvm::Value* out_indices_ptr = + EmitBufferPointer(out_indices_slice, hlo->shape().tuple_shapes(1)); + EmitCallToFunc( + runtime::kTopKF32SymbolName, + {b_.getInt64(has_batch ? input->shape().dimensions(0) : 1), + b_.getInt64(input->shape().dimensions().back()), b_.getInt64(k), + BitCast(values_ptr, b_.getFloatTy()->getPointerTo()), + BitCast(out_values_ptr, b_.getFloatTy()->getPointerTo()), + BitCast(out_indices_ptr, b_.getInt32Ty()->getPointerTo())}, + b_.getVoidTy()); + + llvm_ir::EmitTuple(GetIrArrayFor(hlo), {out_values_ptr, out_indices_ptr}, + &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "PadToStatic") { return HandlePadToStatic(custom_call); @@ -2393,6 +2433,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { if (custom_call->custom_call_target() == "SliceToDynamic") { return HandleSliceToDynamic(custom_call); } + if (custom_call->custom_call_target() == "TopK") { + return HandleTopK(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -3037,10 +3080,21 @@ void IrEmitter::TracingState::EmitTracingEnd(llvm::IRBuilder<>* b, {b->CreateBitCast(run_options, void_ptr_type), activity_id}); } +namespace { +bool IsHloVeryCheap(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kBitcast || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant; +} +} // namespace + Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (instruction_to_profile_idx_.count(hlo)) { - // Only trace the same HLOs that the profiler does. + // When profiling is enabled, trace the same HLOs that the profiler does. + if (instruction_to_profile_idx_.count(hlo) || + (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) { tracing_state_.EmitTracingStart(&b_, hlo, GetExecutableRunOptionsArgument()); profiling_state_.RecordCycleStart(&b_, hlo); @@ -3052,8 +3106,9 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) { if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&b_, hlo, prof_counter); } - // Only trace the same HLOs that the profiler does. - if (instruction_to_profile_idx_.count(hlo)) { + // When profiling is enabled, trace the same HLOs that the profiler does. 
+ if (instruction_to_profile_idx_.count(hlo) || + (hlo_module_config_.cpu_traceme_enabled() && !IsHloVeryCheap(hlo))) { tracing_state_.EmitTracingEnd(&b_, hlo, GetExecutableRunOptionsArgument()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3955deefbea..f136e3470e5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -190,6 +190,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, private: Status HandleSliceToDynamic(HloInstruction* hlo); Status HandlePadToStatic(HloInstruction* hlo); + Status HandleTopK(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc index ff48f554ce6..ae23f224207 100644 --- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -32,7 +32,8 @@ namespace cpu { namespace { // Lower an MLIR module to an LLVM module. -std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { +std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module, + llvm::LLVMContext *context) { // When set, the LLVM backend will be allowed to reassociate floating-point // reductions, which enables much more efficient "horizontal" SIMD // implementations. @@ -47,7 +48,7 @@ std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { mlir::LowerVectorToLLVMOptions().setReassociateFPReductions( kReassociateFPReductions))); CHECK(succeeded(manager.run(*module))); - return mlir::translateModuleToLLVMIR(*module); + return mlir::translateModuleToLLVMIR(*module, *context); } // Get arguments to pass a memref to an mlir function. @@ -114,7 +115,8 @@ Status EmitMlirFuncAndCall( emitter(&op_builder, function); // Now link it all into the main LLVM module. - auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module)); + auto mlir_llvm_module = + MakeLLVMModule(std::move(mlir_module), &b->getContext()); mlir_llvm_module->setDataLayout(llvm_module->getDataLayout()); llvm::Linker::linkModules( *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc index 84cb41a8f17..eac0371b76d 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc @@ -23,16 +23,18 @@ limitations under the License. 
#include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" -using tensorflow::int64; - TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, - int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, - int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, - int64 kernel_filters, int64 output_rows, int64 output_cols, - int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, - int64 padding_left, int64 padding_right, int64 lhs_row_dilation, - int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); @@ -46,13 +48,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, - int64 input_channels, int64 kernel_rows, int64 kernel_cols, - int64 kernel_channels, int64 kernel_filters, int64 output_rows, - int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, - int64 padding_bottom, int64 padding_left, int64 padding_right, - int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, - int64 rhs_col_dilation) { + Eigen::half* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h index 193c25f2a4b..ec634e7f738 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h @@ -19,6 +19,10 @@ limitations under the License. 
#include "tensorflow/core/kernels/eigen_spatial_convolutions.h" #include "tensorflow/core/platform/types.h" +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "tensorflow/core/kernels/eigen_contraction_kernel.h" +#endif + // 'tensorflow' namespace is used so that int64 and other types don't require // qualification. namespace tensorflow { diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc index 0d4e7055ddb..2cee58162fc 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -25,21 +25,16 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -namespace { -using tensorflow::int32; -using tensorflow::int64; -} // namespace - TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( - int64 a, int64 b, int64 c, char** values, int32 values_count, - int32* values_primitive_type_size_in_bytes, bool is_stable, - char* run_options, int64* prof_counters, + tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, char** values, tensorflow::int32 values_count, + tensorflow::int32* values_primitive_type_size_in_bytes, bool is_stable, + char* run_options, tensorflow::int64* prof_counters, void (*less_than)(char*, char*, char**, char**, tensorflow::int64*)) { // 'values' and 'values_primitive_type_size_in_bytes' are managed by the JIT // code, so msan can't tell they are initialized. TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, values_count * sizeof(char*)); TF_ANNOTATE_MEMORY_IS_INITIALIZED(values_primitive_type_size_in_bytes, - values_count * sizeof(int32)); + values_count * sizeof(tensorflow::int32)); // High-level idea of the iteration/sorting logic: // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the @@ -50,16 +45,16 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( // 'base_offset' value which points to the first element in that row, and add // i * c for accessing the 'i'-th element in that row. - int64 sort_dimension_elements = b; - int64 num_iteration_elements = a * c; - int64 sort_dimension_offset = c; + tensorflow::int64 sort_dimension_elements = b; + tensorflow::int64 num_iteration_elements = a * c; + tensorflow::int64 sort_dimension_offset = c; - std::unique_ptr indices(new int64[sort_dimension_elements]); + std::unique_ptr indices(new tensorflow::int64[sort_dimension_elements]); std::unique_ptr comparison_values(new char*[2 * values_count]); std::iota(indices.get(), indices.get() + sort_dimension_elements, 0); std::unique_ptr reordered_values( new std::string[sort_dimension_elements]); - for (int64 index = 0; index < num_iteration_elements; ++index) { + for (tensorflow::int64 index = 0; index < num_iteration_elements; ++index) { // If the sort should be stable, we have to reinitialize indices to iota to // guarantee that we still keep the relative order in case of ties. if (is_stable && index > 0) { @@ -71,14 +66,14 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( // calculating the base offset, we need to multiply the index into the 'a' // dimension with 'b' * 'c'. // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. 
- int64 base_offset = + tensorflow::int64 base_offset = index % sort_dimension_offset + (index - index % sort_dimension_offset) * sort_dimension_elements; - auto compare_function = [&](int64 a, int64 b) -> bool { - for (int32 i = 0; i < values_count; ++i) { - int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * + auto compare_function = [&](tensorflow::int64 a, tensorflow::int64 b) -> bool { + for (tensorflow::int32 i = 0; i < values_count; ++i) { + tensorflow::int64 memory_index_lhs = (base_offset + a * sort_dimension_offset) * values_primitive_type_size_in_bytes[i]; - int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * + tensorflow::int64 memory_index_rhs = (base_offset + b * sort_dimension_offset) * values_primitive_type_size_in_bytes[i]; comparison_values[i * 2] = values[i] + memory_index_lhs; comparison_values[i * 2 + 1] = values[i] + memory_index_rhs; @@ -97,9 +92,9 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( } // Reorder the values according to the order defined by 'indices'. - for (int32 idx = 0; idx < values_count; ++idx) { - for (int64 i = 0; i < sort_dimension_elements; ++i) { - int64 memory_index = + for (tensorflow::int32 idx = 0; idx < values_count; ++idx) { + for (tensorflow::int64 i = 0; i < sort_dimension_elements; ++i) { + tensorflow::int64 memory_index = (base_offset + indices[i] * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; @@ -107,8 +102,8 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSort( std::string(values[idx] + memory_index, values_primitive_type_size_in_bytes[idx]); } - for (int64 i = 0; i < sort_dimension_elements; ++i) { - int64 memory_index = (base_offset + i * sort_dimension_offset) * + for (tensorflow::int64 i = 0; i < sort_dimension_elements; ++i) { + tensorflow::int64 memory_index = (base_offset + i * sort_dimension_offset) * values_primitive_type_size_in_bytes[idx]; memcpy(values[idx] + memory_index, reordered_values[i].c_str(), values_primitive_type_size_in_bytes[idx]); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc index 35db15fed2c..7e19b383d6f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc @@ -27,9 +27,6 @@ limitations under the License. 
#include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif -using tensorflow::int32; -using tensorflow::int64; - namespace { bool Is16BytesAligned(void* ptr) { @@ -37,19 +34,20 @@ bool Is16BytesAligned(void* ptr) { } template -void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { +void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); - int64 lhs_rows = m; - int64 lhs_cols = k; + tensorflow::int64 lhs_rows = m; + tensorflow::int64 lhs_cols = k; if (transpose_lhs) { std::swap(lhs_rows, lhs_cols); } - int64 rhs_rows = k; - int64 rhs_cols = n; + tensorflow::int64 rhs_rows = k; + tensorflow::int64 rhs_cols = n; if (transpose_rhs) { std::swap(rhs_rows, rhs_cols); } @@ -75,8 +73,9 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, template void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, - int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { bool all_buffers_16b_aligned = Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); @@ -94,45 +93,52 @@ void MatMulDispatch(const void* run_options_ptr, T* out, T* lhs, T* rhs, TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + Eigen::half* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF32( - const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { + const void* run_options_ptr, float* out, float* lhs, float* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulF64( - const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { + const void* run_options_ptr, double* out, double* lhs, double* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC64( const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } 
TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulC128( const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { MatMulDispatch>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenMatMulS32( - const void* run_options_ptr, int32* out, int32* lhs, int32* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, - transpose_rhs); + const void* run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs, + tensorflow::int32* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + MatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, + transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc index 5afccc6a86e..360ce57e808 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -19,18 +19,20 @@ limitations under the License. #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/types.h" -using tensorflow::int64; - TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConvF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, - int64 input_channels, int64 kernel_rows, int64 kernel_cols, - int64 kernel_channels, int64 kernel_filters, int64 output_rows, - int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, - int64 padding_bottom, int64 padding_left, int64 padding_right, - int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, - int64 rhs_col_dilation) { + Eigen::half* rhs, tensorflow::int64 input_batch, + tensorflow::int64 input_rows, tensorflow::int64 input_cols, + tensorflow::int64 input_channels, tensorflow::int64 kernel_rows, + tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels, + tensorflow::int64 kernel_filters, tensorflow::int64 output_rows, + tensorflow::int64 output_cols, tensorflow::int64 row_stride, + tensorflow::int64 col_stride, tensorflow::int64 padding_top, + tensorflow::int64 padding_bottom, tensorflow::int64 padding_left, + tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation, + tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, + tensorflow::int64 rhs_col_dilation) { tensorflow::xla::EigenConvImpl( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, @@ -42,12 +44,16 @@ __xla_cpu_runtime_EigenSingleThreadedConvF16( TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConvF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, - int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels, - int64 kernel_rows, int64 kernel_cols, int64 kernel_channels, - int64 kernel_filters, int64 output_rows, int64 output_cols, - int64 row_stride, 
int64 col_stride, int64 padding_top, int64 padding_bottom, - int64 padding_left, int64 padding_right, int64 lhs_row_dilation, - int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation) { tensorflow::xla::EigenConvImpl( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc index c7601f939c7..a8112c1106b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc @@ -23,9 +23,6 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif -using tensorflow::int32; -using tensorflow::int64; - namespace { bool Is16BytesAligned(void* ptr) { @@ -33,16 +30,17 @@ bool Is16BytesAligned(void* ptr) { } template -void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, - int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) { - int64 lhs_rows = m; - int64 lhs_cols = k; +void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { + tensorflow::int64 lhs_rows = m; + tensorflow::int64 lhs_cols = k; if (transpose_lhs) { std::swap(lhs_rows, lhs_cols); } - int64 rhs_rows = k; - int64 rhs_cols = n; + tensorflow::int64 rhs_rows = k; + tensorflow::int64 rhs_cols = n; if (transpose_rhs) { std::swap(rhs_rows, rhs_cols); } @@ -67,8 +65,10 @@ void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m, template void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, - T* rhs, int64 m, int64 n, int64 k, - int32 transpose_lhs, int32 transpose_rhs) { + T* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { bool all_buffers_16b_aligned = Is16BytesAligned(out) && Is16BytesAligned(lhs) && Is16BytesAligned(rhs); @@ -86,28 +86,27 @@ void SingleThreadedMatMulDispatch(const void* run_options_ptr, T* out, T* lhs, TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedMatMulF16( const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, - Eigen::half* rhs, int64 m, int64 n, int64 k, int32 transpose_lhs, - int32 transpose_rhs) { + Eigen::half* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF32(const 
void* run_options_ptr, - float* out, float* lhs, - float* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +__xla_cpu_runtime_EigenSingleThreadedMatMulF32( + const void* run_options_ptr, float* out, float* lhs, float* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, - double* out, double* lhs, - double* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { +__xla_cpu_runtime_EigenSingleThreadedMatMulF64( + const void* run_options_ptr, double* out, double* lhs, double* rhs, + tensorflow::int64 m, tensorflow::int64 n, tensorflow::int64 k, + tensorflow::int32 transpose_lhs, tensorflow::int32 transpose_rhs) { SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } @@ -115,8 +114,9 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(const void* run_options_ptr, TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedMatMulC64( const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { SingleThreadedMatMulDispatch>( run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } @@ -124,18 +124,19 @@ __xla_cpu_runtime_EigenSingleThreadedMatMulC64( TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedMatMulC128( const void* run_options_ptr, std::complex* out, - std::complex* lhs, std::complex* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, int32 transpose_rhs) { + std::complex* lhs, std::complex* rhs, tensorflow::int64 m, + tensorflow::int64 n, tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { SingleThreadedMatMulDispatch>( run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_EigenSingleThreadedMatMulS32(const void* run_options_ptr, - int32* out, int32* lhs, - int32* rhs, int64 m, int64 n, - int64 k, int32 transpose_lhs, - int32 transpose_rhs) { - SingleThreadedMatMulDispatch(run_options_ptr, out, lhs, rhs, m, n, k, - transpose_lhs, transpose_rhs); +__xla_cpu_runtime_EigenSingleThreadedMatMulS32( + const void* run_options_ptr, tensorflow::int32* out, tensorflow::int32* lhs, + tensorflow::int32* rhs, tensorflow::int64 m, tensorflow::int64 n, + tensorflow::int64 k, tensorflow::int32 transpose_lhs, + tensorflow::int32 transpose_rhs) { + SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); } diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.cc b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc new file mode 100644 index 00000000000..5174a3329fb --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.cc @@ -0,0 +1,76 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h" + +#include <algorithm> +#include <cstring> +#include <limits> +#include <numeric> +#include <vector> + +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" + +template <typename T> +static void TopK(tensorflow::int64 batch_size, tensorflow::int64 input_size, + tensorflow::int64 k, const T* values, T* out_values, + tensorflow::int32* out_indices) { + // 'values' is managed by the JIT code, so msan can't tell they are + // initialized. + TF_ANNOTATE_MEMORY_IS_INITIALIZED(values, + input_size * batch_size * sizeof(T)); + + std::vector<tensorflow::int32> temp_indices(input_size); + for (tensorflow::int64 batch = 0; batch != batch_size; ++batch) { + std::iota(temp_indices.begin(), temp_indices.end(), 0); + + const T* values_batch = values + batch * input_size; + + auto convert_to_int = [](T value) { + tensorflow::uint32 x; + std::memcpy(&x, &value, sizeof(x)); + return static_cast<tensorflow::int32>(x) < 0 + ? std::numeric_limits<tensorflow::int32>::max() - x + : x; + }; + + auto kth_element = temp_indices.begin() + k; + std::partial_sort(temp_indices.begin(), kth_element, temp_indices.end(), + [&](size_t i1, size_t i2) { + // Do the comparison in integers to enforce a total + // order of -NaN < -Inf < -0 < +0 < +Inf < +NaN. + tensorflow::int32 v1 = convert_to_int(values_batch[i1]); + tensorflow::int32 v2 = convert_to_int(values_batch[i2]); + if (v1 == v2) { + return i1 < i2; // Stabilize sorting. + } + return v1 > v2; + }); + + T* out_values_batch = out_values + batch * k; + tensorflow::int32* out_indices_batch = out_indices + batch * k; + std::copy(temp_indices.begin(), kth_element, out_indices_batch); + for (tensorflow::int64 i = 0; i < k; i++) { + out_values_batch[i] = values_batch[temp_indices[i]]; + } + } +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_TopKF32( + tensorflow::int64 batch_size, tensorflow::int64 input_size, + tensorflow::int64 k, const float* values, float* out_values, + tensorflow::int32* out_indices) { + TopK(batch_size, input_size, k, values, out_values, out_indices); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_topk.h b/tensorflow/compiler/xla/service/cpu/runtime_topk.h new file mode 100644 index 00000000000..de69c0603e3 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_topk.h @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
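Regarding the integer comparison used in the TopK comparator above: reinterpreting the float bits and remapping values with the sign bit set yields a total order -NaN < -Inf < -0 < +0 < +Inf < +NaN on an IEEE-754 platform. A standalone demonstration of the same mapping (an illustrative sketch, not part of the patch, written with fixed-width standard types):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <limits>

// Same transformation as convert_to_int in runtime_topk.cc.
static int32_t ToOrderedInt(float value) {
  uint32_t x;
  std::memcpy(&x, &value, sizeof(x));
  return static_cast<int32_t>(x) < 0 ? std::numeric_limits<int32_t>::max() - x
                                     : x;
}

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // Total order: -NaN < -Inf < -0 < +0 < +Inf < +NaN.
  assert(ToOrderedInt(-nan) < ToOrderedInt(-inf));
  assert(ToOrderedInt(-inf) < ToOrderedInt(-0.0f));
  assert(ToOrderedInt(-0.0f) < ToOrderedInt(0.0f));
  assert(ToOrderedInt(0.0f) < ToOrderedInt(inf));
  assert(ToOrderedInt(inf) < ToOrderedInt(nan));
  // Ordinary finite values keep their usual order.
  assert(ToOrderedInt(-2.5f) < ToOrderedInt(-1.0f));
  assert(ToOrderedInt(1.0f) < ToOrderedInt(2.5f));
  return 0;
}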
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Calculates `batch_size` topk operations with `input_size` inputs each. The +// outputs are written to `out_values` and `out_indices`. +extern void __xla_cpu_runtime_TopKF32(tensorflow::int64 batch_size, + tensorflow::int64 input_size, + tensorflow::int64 k, const float* values, + float* out_values, + tensorflow::int32* out_indices); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_TOPK_H diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 631c6985b03..28508bde4cd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_topk.h" #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/types.h" @@ -270,6 +271,7 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); + REGISTER_CPU_RUNTIME_SYMBOL(TopKF32); REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index d7c50dce3ca..527071d5f31 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -253,6 +253,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "cpu_topk_test", + srcs = ["cpu_topk_test.cc"], + deps = [ + ":cpu_codegen_test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:sorting", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "cpu_vectorization_test", srcs = ["cpu_vectorization_test.cc"], diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc new file mode 100644 index 00000000000..b7647fb4b16 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_topk_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/client/lib/sorting.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" + +namespace xla { +namespace cpu { +namespace { + +using CpuTopKTest = CpuCodegenTest; + +TEST_F(CpuTopKTest, CallRuntimeUnbatched) { + XlaBuilder builder(TestName()); + XlaOp input = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {100}), "input"); + TopK(input, 10); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSERT_OK_AND_ASSIGN( + auto module, HloModule::CreateFromProto(xla_computation.proto(), config)); + + constexpr char filecheck_pattern[] = R"( + CHECK: call void @__xla_cpu_runtime_TopKF32(i64 1, i64 100, i64 10, + )"; + + CpuAotCompilationOptions options{ + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +TEST_F(CpuTopKTest, CallRuntimeBatched) { + XlaBuilder builder(TestName()); + XlaOp input = + Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 100}), "input"); + TopK(input, 10); + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, builder.Build()); + + TF_ASSERT_OK_AND_ASSIGN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSERT_OK_AND_ASSIGN( + auto module, HloModule::CreateFromProto(xla_computation.proto(), config)); + + constexpr char filecheck_pattern[] = R"( + CHECK: call void @__xla_cpu_runtime_TopKF32(i64 5, i64 100, i64 10, + )"; + + CpuAotCompilationOptions options{ + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc index fcdf85d5ecb..4670ce6940a 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.cc +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.cc @@ -24,6 +24,31 @@ limitations under the License. namespace xla { namespace dot_as_convolution_util { +bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size) { + // A parallel batch dimension in DotGeneral is represented as a + // spatial dimension with window size B (batch dimension size), + // stride B - 1, and base dilation B. 
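As a quick sanity check of the window configuration described above (an illustrative sketch, not part of the patch): with B = 4, a window of size B, base dilation B and stride B - 1 reproduces an output dimension of size B, which is why such a spatial dimension behaves like a parallel batch dimension.

#include <cassert>

int main() {
  const long long B = 4;  // batch dimension size
  const long long window_size = B;
  const long long base_dilation = B;
  const long long stride = B - 1;
  // Base-dilated input extent, with window_dilation == 1 and no padding.
  const long long dilated_input = (B - 1) * base_dilation + 1;  // 13
  const long long output_size = (dilated_input - window_size) / stride + 1;
  assert(output_size == B);  // the dimension passes through unchanged
  return 0;
}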
+ if (lhs_size == wd.size() && lhs_size == wd.base_dilation() && + ((std::max<int64>(1, lhs_size - 1) == wd.stride() && + wd.window_dilation() == 1) || + (std::max<int64>(1, lhs_size - 1) == wd.window_dilation() && + wd.stride() == 1)) && + wd.padding_high() == 0 && wd.padding_low() == 0 && + !wd.window_reversal()) { + return true; + } + + // Alternative representation of a batch dimension. + if (wd.size() == lhs_size && wd.padding_high() == lhs_size - 1 && + wd.padding_low() == lhs_size - 1 && wd.window_reversal() && + wd.window_dilation() == 1 && wd.stride() == lhs_size && + wd.base_dilation() == lhs_size - 1) { + return true; + } + + return false; +} + /* static */ absl::optional<DotGeneralAsConvolutionDimsInfo> ParseDotGeneralFromConvolution(const HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); @@ -49,14 +74,7 @@ ParseDotGeneralFromConvolution(const HloInstruction* conv) { int64 rhs_size = conv->operand(1)->shape().dimensions(rhs); int64 output = conv_dims.output_spatial_dimensions(i); const auto& wd = conv->window().dimensions(i); - if (lhs_size == wd.size() && - std::max<int64>(1, lhs_size - 1) == wd.stride() && - lhs_size == wd.base_dilation() && wd.window_dilation() == 1 && - wd.padding_high() == 0 && wd.padding_low() == 0 && - !wd.window_reversal()) { - // A batch dimension in DotGeneral is represented as a spatial dimension - // with window size B (batch dimension size), stride B - 1, and base - // dilation B. + if (ConvSpatialDimensionIsParallel(wd, lhs_size)) { dims.batch_dims.push_back({lhs, rhs, output, i}); } else if (lhs_size == wd.size() && wd.base_dilation() == 1 && wd.window_dilation() == 1 && wd.padding_high() == 0 && diff --git a/tensorflow/compiler/xla/service/dot_as_convolution_util.h b/tensorflow/compiler/xla/service/dot_as_convolution_util.h index a3e829a3d31..6a7cacf812d 100644 --- a/tensorflow/compiler/xla/service/dot_as_convolution_util.h +++ b/tensorflow/compiler/xla/service/dot_as_convolution_util.h @@ -62,6 +62,12 @@ CreateShardedConvForDotGeneralConvolution( const DotGeneralAsConvolutionDimsInfo& dot_dnums, HloInstruction* sharded_lhs_hlo, HloInstruction* sharded_rhs_hlo); +// Check if a spatial dim is a parallel batch dimension. +// A parallel batch dimension in DotGeneral is represented as a spatial +// dimension with window size B (batch dimension size), stride B - 1, and base +// dilation B.
+bool ConvSpatialDimensionIsParallel(const WindowDimension& wd, int64 lhs_size); + } // namespace dot_as_convolution_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 6ebbf622614..36429d3d755 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -148,15 +148,12 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleDomain(HloInstruction* hlo) override; private: - using DimensionConstraint = DynamicDimensionInference::DimensionConstraint; using OperandDynamicDimensionFn = std::function; + int64 operand_index, HloInstruction* dynamic_size)>; using DynamicDimensionFn = std::function; + ShapeIndex index, int64 dimension, HloInstruction* dynamic_size)>; Status ForEachOperandDynamicDimension(HloInstruction* inst, const OperandDynamicDimensionFn&); @@ -184,8 +181,7 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status DynamicDimensionInferenceVisitor::DefaultAction(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { return UnimplementedStrCat( "Asked to propagate a dynamic dimension from hlo ", operand->name(), "@", index.ToString(), "@", dimension, " to hlo ", hlo->ToString(), @@ -197,13 +193,11 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { if (hlo->tuple_index() == index[0]) { ShapeIndex new_index = ShapeIndexView(index).ConsumeFront().ToShapeIndex(); - parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, new_index, dimension, dynamic_size); } return Status::OK(); }); @@ -212,11 +206,9 @@ Status DynamicDimensionInferenceVisitor::HandleGetTupleElement( Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { index.push_front(operand_index); - parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); return Status::OK(); }); } @@ -224,11 +216,9 @@ Status DynamicDimensionInferenceVisitor::HandleTuple(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleBroadcast(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { int64 broadcast_dim = hlo->dimensions(dimension); - parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, broadcast_dim, dynamic_size); return Status::OK(); }); } @@ -244,8 +234,7 @@ Status 
DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) { // returns the padded data output and the dynamic sizes of input // dimensions. ShapeIndex data_output = {0}; - parent_->SetDynamicSize(hlo, data_output, i, dynamic_size, - DimensionConstraint(1, 1)); + parent_->SetDynamicSize(hlo, data_output, i, dynamic_size); } } return Status::OK(); @@ -255,15 +244,14 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) { } return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { // Resize custom call should propagate dynamic batch (0) and channel (3) // dimensions. if (hlo->custom_call_target() == "SliceToDynamic" || hlo->custom_call_target() == "Sharding" || (absl::StartsWith(hlo->custom_call_target(), "Resize") && (dimension == 0 || dimension == 3))) { - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); return Status::OK(); } return Unimplemented( @@ -274,16 +262,15 @@ Status DynamicDimensionInferenceVisitor::HandleCustomCall(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) { return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* operand, ShapeIndex index, - int64 dynamic_dimension, int64 operand_index, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + hlo, + [&](HloInstruction* operand, ShapeIndex index, int64 dynamic_dimension, + int64 operand_index, HloInstruction* dynamic_size) { HloSortInstruction* sort = Cast(hlo); if (sort->values_count() == 0) { - parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, dynamic_dimension, dynamic_size); } else { parent_->SetDynamicSize(hlo, {operand_index}, dynamic_dimension, - dynamic_size, constraint); + dynamic_size); } return Status::OK(); @@ -293,8 +280,7 @@ Status DynamicDimensionInferenceVisitor::HandleSort(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { if (operand_index != 0) { return Unimplemented( "Dynamic dimension on padding value is not supported"); @@ -311,8 +297,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { hlo->parent()->AddInstruction(HloInstruction::CreateBinary( dynamic_size_adjusted->shape(), HloOpcode::kAdd, dynamic_size_adjusted, adjustment)); - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted, - constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size_adjusted); return Status::OK(); } else { return Unimplemented( @@ -327,8 +312,7 @@ Status DynamicDimensionInferenceVisitor::HandlePad(HloInstruction* hlo) { Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { HloInstruction* reduce = hlo; int64 operand_count = reduce->operand_count(); bool 
is_variadic_reduce = operand_count > 2; @@ -354,13 +338,12 @@ Status DynamicDimensionInferenceVisitor::HandleReduce(HloInstruction* hlo) { // reduce has a dynamic dimension, we set all outputs to use the // same dynamic size in corresponding dimensions. for (int64 i = 0; i < operand_count / 2; ++i) { - parent_->SetDynamicSize(reduce, {i}, - dimensions_not_reduced_count, - dynamic_size, constraint); + parent_->SetDynamicSize( + reduce, {i}, dimensions_not_reduced_count, dynamic_size); } } else { parent_->SetDynamicSize(reduce, {}, dimensions_not_reduced_count, - dynamic_size, constraint); + dynamic_size); } return Status::OK(); @@ -378,7 +361,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex operand_shape_index, int64 operand_dimension, int64 operand_index, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + HloInstruction* dynamic_size) { // There are three types of dimensions in a dot: // A. batch dims // B. contracting dims @@ -451,8 +434,7 @@ Status DynamicDimensionInferenceVisitor::HandleDot(HloInstruction* hlo) { // work item to trace that dimension. auto iter = result_dim_mapping.find(operand_dimension); if (iter != result_dim_mapping.end()) { - parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size, - constraint); + parent_->SetDynamicSize(dot, {}, iter->second, dynamic_size); } return Status::OK(); @@ -463,8 +445,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) -> Status { + int64 operand_index, HloInstruction* dynamic_size) -> Status { int64 permuted_dim = -1; for (int64 i = 0; i < hlo->dimensions().size(); ++i) { if (hlo->dimensions()[i] == dimension) { @@ -472,8 +453,7 @@ Status DynamicDimensionInferenceVisitor::HandleTranspose(HloInstruction* hlo) { permuted_dim = i; } } - parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, permuted_dim, dynamic_size); return Status::OK(); }); } @@ -482,8 +462,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { HloInstruction* conv = hlo; const ConvolutionDimensionNumbers& dimension_numbers = conv->convolution_dimension_numbers(); @@ -492,7 +471,7 @@ Status DynamicDimensionInferenceVisitor::HandleConvolution( if (dimension == dimension_numbers.input_batch_dimension()) { parent_->SetDynamicSize(conv, {}, dimension_numbers.output_batch_dimension(), - dynamic_size, constraint); + dynamic_size); return Status::OK(); } @@ -542,20 +521,18 @@ Status DynamicDimensionInferenceVisitor::HandleConcatenate( dim_size_total, dynamic_dim)); } parent_->SetDynamicSize(hlo, {}, hlo->concatenate_dimension(), - dim_size_total, DimensionConstraint(1, 1)); + dim_size_total); } // Simply pass through non-concat dynamic dimensions. 
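For intuition, a brief illustrative note on the concatenate handling above (names taken from the code, example values assumed):

// Example (illustrative, not part of the patch): concatenating f32[<=3] and
// f32[4] along dimension 0 accumulates the operands' sizes into
// dim_size_total, so the output's dynamic size on the concat dimension is
// first_operand_size + 4, while a dynamic size on any other dimension is
// simply forwarded by the pass-through below.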
return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { int64 concatenate_dimension = hlo->concatenate_dimension(); if (concatenate_dimension == dimension) { return Status::OK(); } - parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); return Status::OK(); }); } @@ -596,18 +573,15 @@ Status DynamicDimensionInferenceVisitor::HandleSetDimensionSize( if (!dimension_is_static) { // Propagate dynamic dimension indicated by this set dimension size // instruction. - parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1), - DimensionConstraint(1, 1)); + parent_->SetDynamicSize(hlo, {}, hlo->dimension(), hlo->mutable_operand(1)); } // Also Propagate dynamic dimension already set by operands. TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { if (dimension != hlo->dimension()) { - parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); } return Status::OK(); })); @@ -619,10 +593,8 @@ Status DynamicDimensionInferenceVisitor::PassThroughDynamicDimension( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { - parent_->SetDynamicSize(hlo, index, dimension, dynamic_size, - constraint); + int64 operand_index, HloInstruction* dynamic_size) { + parent_->SetDynamicSize(hlo, index, dimension, dynamic_size); return Status::OK(); }); } @@ -654,8 +626,7 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { hlo, [&](HloInstruction* operand, ShapeIndex index, int64 input_dynamic_dimension, int64 operand_index, - HloInstruction* operand_dynamic_size, - DimensionConstraint constraint) -> Status { + HloInstruction* operand_dynamic_size) -> Status { HloInstruction* reshape = hlo; if (reshape->shape().rank() == 0) { VLOG(0) << "Reshaping a dynamic dimension into a scalar, which has " @@ -751,9 +722,6 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { if (output_dynamic_dimension == -1 && output_dim_end - output_dim_start > 1) { - // TODO(yunxing): We now have a better way to decide output dimension - // in the bridge. No need for this constraint propagation logic. - // // One input dimension is splitted into multiple output dimensions. // Output dimension is decomposed from input most major dimension. // In this case, we don't know which one is dynamic, e.g., when we @@ -770,61 +738,17 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { // We use the following logics to disambiguate: // 1. If the user sets "inferred_dimension", then use that as // dynamic dimension. + // 2. If the one dimension in the reshape is dynamic, use that as + // dynamic dimension. + // E.g.: + // [<=4] + // | + // reshape + // | + // [1, <=2, 2] + // We use second dim as dynamic dimension. // - // 2. 
Use the "multiple_of" constraint, e.g, : - // [<=2, 4] - // | Reshape - // [<=8] - // | Reshape - // [2, 4] // Which is dynamic? - // - // If the dynamic value has to be multiple of 4 (constraint - // created by the first reshape), then 2 must be the dynamic - // dimension. - // - // But this logic doesn't help with the case where two - // dimensions are the same: - // - // [<=3, 3] - // | Reshape - // [<=9] - // | Reshape - // [3, 3] // Which is dynamic? - // - // Both dynamic dimension can be multiple of 3. - // - // We then need the next constraint to disambiguate this case: - // - // 3. Use the "stride" constraint (also see the comment at the - // definition): - // - // [<=3, 3] - // | Reshape - // [<=9] // constraint.stride = 1 - // | Reshape - // [3, 3] - // ^ ^ - // | | - // stride= 1 3 - // - // Each dimension will have different strides, only one will - // satisfy the stride constraint. - // - // Note that the stride constrint itself is not enough: - // - // - // [<=128] - // | Reshape - // [1, 128] - // ^ ^ - // | | - // stride= 1 1 - // - // In this case, both dimensions have the same stride, which is - // ambiguous. That's why we need the "multiple_of" constraint - // as used above. - // - // 4. If all logics above cannot disambiguate, e.g.,: + // 3. If all logics above cannot disambiguate, e.g.,: // // [<=1] // | @@ -833,68 +757,15 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { // [1, 1, 1] // // We bail out and return an error. + // TODO(yunxing): Further simplify this, remove 1. and fully rely + // on 2. output_dynamic_dimension = reshape->inferred_dimension(); if (output_dynamic_dimension == -1) { - // The user of XLA didn't specify a dynamic dimension, try infer - // it from the current constraint. - // - // Find all output dimensions that are decomposed from the first - // dimension. Among those dimensions, find all dimensions that - // satisfy the constraint of the dynamic dimension. In the - // previous example, if `a` is 9 and constraint is a multiple of - // `3', then in the output shape both a/c and c can be dynamic. - int64 current_product = 1; - int64 dimension_iter = output_dim_start; - - // compatible_dimensions are dimensions that satisfies - // "multiple_of" constraints. - std::vector compatible_dimensions; - while (current_product < - operand->shape().dimensions(input_dynamic_dimension)) { - current_product *= reshape->shape().dimensions(dimension_iter); - if (operand->shape().dimensions(input_dynamic_dimension) / - reshape->shape().dimensions(dimension_iter) == - constraint.multiple_of) { - compatible_dimensions.push_back(dimension_iter); + // Try find dynamic dimension from the result shape. + for (int64 i = 0; i < reshape->shape().rank(); ++i) { + if (reshape->shape().is_dynamic_dimension(i)) { + output_dynamic_dimension = i; } - dimension_iter++; - } - CHECK_EQ(current_product, - operand->shape().dimensions(input_dynamic_dimension)) - << "Not a valid reshape: " << hlo->ToString(); - // If there is only one compatible dimension, it must be the - // dynamic one in the output. - if (compatible_dimensions.size() == 1) { - output_dynamic_dimension = compatible_dimensions[0]; - } - - // When there are multiple compatible dimensions, e.g: - // [<=9] - // | Reshape - // [3, 3] - // Use stride constraint to figure out which one is the true - // dynamic one. 
- // - // [<=9] - // | Reshape - // [3, 3] - // ^ ^ - // | | - // stride= 1 3 - // - std::vector compatible_dimensions_with_stride; - absl::c_copy_if( - compatible_dimensions, - std::back_inserter(compatible_dimensions_with_stride), - [&](int64 dimension) { - int64 stride_total = 1; - for (int64 i = 0; i < dimension + 1; ++i) { - stride_total *= reshape->shape().dimensions(dimension); - } - return stride_total == constraint.stride; - }); - if (compatible_dimensions_with_stride.size() == 1) { - output_dynamic_dimension = compatible_dimensions_with_stride[0]; } } @@ -914,9 +785,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return InvalidArgument( "Reshape's input dynamic dimension is decomposed into " "multiple output dynamic dimensions, but the constraint is " - "ambiguous and XLA can't infer the output dimension %s. " - "Constraint: multiple_of: %lld, stride: %lld", - hlo->ToString(), constraint.multiple_of, constraint.stride); + "ambiguous and XLA can't infer the output dimension %s. ", + hlo->ToString()); } } @@ -931,11 +801,12 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { if (input_dim_size == output_dim_size) { // Simply forward dynamic dimension. parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension, - operand_dynamic_size, constraint); + operand_dynamic_size); } if (input_dim_size > output_dim_size) { - TF_RET_CHECK(input_dim_size % output_dim_size == 0); + TF_RET_CHECK(input_dim_size % output_dim_size == 0) + << reshape->ToString(); const int64 divisor = input_dim_size / output_dim_size; HloInstruction* divisor_hlo = hlo->parent()->AddInstruction(HloInstruction::CreateConstant( @@ -946,9 +817,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { operand_dynamic_size->shape(), HloOpcode::kDivide, operand_dynamic_size, divisor_hlo)); - parent_->SetDynamicSize( - reshape, {}, output_dynamic_dimension, new_dynamic_size, - DimensionConstraint(1, constraint.multiple_of / divisor)); + parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension, + new_dynamic_size); } if (input_dim_size < output_dim_size) { @@ -985,12 +855,8 @@ Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { hlo->parent()->AddInstruction(HloInstruction::CreateBinary( output_dynamic_size->shape(), HloOpcode::kMultiply, new_dynamic_size, operand_dynamic_size)); - int64 new_multiple_of_constraint = - constraint.multiple_of * output_dim_size / - operand->shape().dimensions(input_dynamic_dimension); - parent_->SetDynamicSize( - reshape, {}, output_dynamic_dimension, new_dynamic_size, - DimensionConstraint(1, new_multiple_of_constraint)); + parent_->SetDynamicSize(reshape, {}, output_dynamic_dimension, + new_dynamic_size); } return Status::OK(); @@ -1001,8 +867,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { HloInstruction* reduce_window = hlo; const WindowDimension& window_dimension = reduce_window->window().dimensions(dimension); @@ -1013,8 +878,7 @@ Status DynamicDimensionInferenceVisitor::HandleReduceWindow( reduce_window->ToString()); } - parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(reduce_window, {}, dimension, dynamic_size); return 
Status::OK(); }); @@ -1024,8 +888,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { HloInstruction* select_and_scatter = hlo; const WindowDimension& window_dimension = select_and_scatter->window().dimensions(dimension); @@ -1036,8 +899,8 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( select_and_scatter->ToString()); } - parent_->SetDynamicSize(select_and_scatter, {}, dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(select_and_scatter, {}, dimension, + dynamic_size); return Status::OK(); }); @@ -1046,8 +909,7 @@ Status DynamicDimensionInferenceVisitor::HandleSelectAndScatter( Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 dimension, - int64 /*operand_index*/, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 /*operand_index*/, HloInstruction* dynamic_size) { if (hlo->slice_starts(dimension) != 0 || hlo->slice_strides(dimension) != 1 || hlo->slice_limits(dimension) != @@ -1056,7 +918,7 @@ Status DynamicDimensionInferenceVisitor::HandleSlice(HloInstruction* hlo) { return Status::OK(); } - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); return Status::OK(); }); @@ -1066,8 +928,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice( HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction*, ShapeIndex /*index*/, int64 dimension, - int64 /*operand_index*/, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 /*operand_index*/, HloInstruction* dynamic_size) { if (hlo->shape().dimensions(dimension) != hlo->operand(0)->shape().dimensions(dimension)) { // Slicing a single element out kills the dynamic dimension. 
@@ -1080,7 +941,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice( hlo->ToString()); } - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); return Status::OK(); }); @@ -1089,9 +950,9 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicSlice( Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice( HloInstruction* hlo) { return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, - int64 dimension, int64 /*operand_index*/, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + hlo, + [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension, + int64 /*operand_index*/, HloInstruction* dynamic_size) { if (hlo->shape().dimensions(dimension) != hlo->operand(0)->shape().dimensions(dimension)) { return Unimplemented( @@ -1100,7 +961,7 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice( hlo->ToString()); } - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); return Status::OK(); }); @@ -1108,16 +969,16 @@ Status DynamicDimensionInferenceVisitor::HandleDynamicUpdateSlice( Status DynamicDimensionInferenceVisitor::HandleReverse(HloInstruction* hlo) { return ForEachOperandDynamicDimension( - hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, - int64 dimension, int64 /*operand_index*/, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + hlo, + [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension, + int64 /*operand_index*/, HloInstruction* dynamic_size) { if (absl::c_linear_search(hlo->dimensions(), dimension)) { return Unimplemented( "Dynamic dimension propagation on reversed dimension is not " "supported %s", hlo->ToString()); } - parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size, constraint); + parent_->SetDynamicSize(hlo, {}, dimension, dynamic_size); return Status::OK(); }); @@ -1127,7 +988,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* operand, ShapeIndex /*index*/, int64 input_dynamic_dimension, int64 operand_index, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + HloInstruction* dynamic_size) { const GatherDimensionNumbers& gather_dims = hlo->gather_dimension_numbers(); if (operand_index != 1) { @@ -1147,8 +1008,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) { output_dimension--; } } - parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, output_dimension, dynamic_size); return Status::OK(); } return Unimplemented( @@ -1171,8 +1031,7 @@ Status DynamicDimensionInferenceVisitor::HandleGather(HloInstruction* hlo) { indices_dim++; } if (indices_dim++ == input_dynamic_dimension) { - parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, output_dim, dynamic_size); return Status::OK(); } } @@ -1220,8 +1079,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional( TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( hlo, operand_index, [&](HloInstruction*, ShapeIndex, int64, int64, - HloInstruction* dynamic_size, - DimensionConstraint constraint) -> Status { + HloInstruction* dynamic_size) -> Status { TF_RET_CHECK(hlo->operand(operand_index)->shape().IsTuple()) << "Only tuple typed inputs can have dynamic 
dimension. Please " "file a bug against XLA team."; @@ -1263,8 +1121,7 @@ Status DynamicDimensionInferenceVisitor::HandleConditional( TF_RETURN_IF_ERROR(ForEachDynamicDimensionInOperand( hlo, operand_index, [&](HloInstruction*, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* dynamic_size) { DynamicParameterBinding::DynamicParameter dynamic_parameter{ 0, {dynamic_size_to_operand_id_index_map[dynamic_size]}}; DynamicParameterBinding::DynamicDimension dynamic_dimension{ @@ -1284,8 +1141,8 @@ Status DynamicDimensionInferenceVisitor::HandleConditional( // that into the root instruction as additional tuple elements. TF_RETURN_IF_ERROR(ForEachDynamicDimension( new_computation->root_instruction(), - [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size, - DimensionConstraint) -> Status { + [&](ShapeIndex index, int64 dim, + HloInstruction* dynamic_size) -> Status { TF_RET_CHECK(hlo->shape().IsTuple()) << "Only tuple typed conditionals can have dynamic dimension. " "Please file a bug against XLA team."; @@ -1347,11 +1204,9 @@ Status DynamicDimensionInferenceVisitor::HandleScatter(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, [&](HloInstruction* /*operand*/, ShapeIndex /*index*/, int64 dimension, - int64 operand_index, HloInstruction* operand_dynamic_size, - DimensionConstraint constraint) { + int64 operand_index, HloInstruction* operand_dynamic_size) { if (operand_index == 0) { - parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size, - constraint); + parent_->SetDynamicSize(hlo, {}, dimension, operand_dynamic_size); return Status::OK(); } @@ -1385,7 +1240,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { int64 operand_count = original_tuple_count; TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction*, ShapeIndex index, int64 dim, int64, - HloInstruction* dynamic_size, DimensionConstraint constraint) { + HloInstruction* dynamic_size) { operands_to_add.push_back(dynamic_size); dynamic_output_mapping.mutable_element(index)->emplace(dim, operand_count++); @@ -1413,8 +1268,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { TF_RETURN_IF_ERROR(ForEachOperandDynamicDimension( hlo, [&](HloInstruction*, ShapeIndex index, int64 dimension, - int64 operand_index, HloInstruction* dynamic_size, - DimensionConstraint constraint) -> Status { + int64 operand_index, HloInstruction* dynamic_size) -> Status { TF_RET_CHECK(!operands_to_add.empty()); const int64 output_dynamic_size_index = dynamic_output_mapping.element(index).at(dimension); @@ -1431,7 +1285,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { ShapeUtil::MakeScalarShape(S32), hlo, output_dynamic_size_index)); parent_->SetDynamicSize(result.replacement_instr, index, dimension, - output_dynamic_size, constraint); + output_dynamic_size); return Status::OK(); })); // Set the replacement instruction as visited to avoid visiting it again. @@ -1465,8 +1319,7 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) { // Add dynamic dimension size as new parameters. 
TF_RETURN_IF_ERROR(ForEachDynamicDimension( hlo->while_body()->root_instruction(), - [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size, - DimensionConstraint) -> Status { + [&](ShapeIndex index, int64 dim, HloInstruction* dynamic_size) -> Status { const int64 output_index = dynamic_output_mapping.element(index).at(dim); new_root_operands[output_index] = dynamic_size; @@ -1503,8 +1356,7 @@ Status DynamicDimensionInferenceVisitor::HandleParameter(HloInstruction* hlo) { parent_->SetDynamicSize(target_parameter, dynamic_dimension.parameter_index, - dynamic_dimension.dimension, dynamic_size, - DimensionConstraint(1, 1)); + dynamic_dimension.dimension, dynamic_size); return Status::OK(); }); } @@ -1517,10 +1369,8 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimension( HloInstruction* dynamic_size = parent_->GetDynamicSize( dynamic_dimension.inst, dynamic_dimension.index, dynamic_dimension.dim); - CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0); - TF_RETURN_IF_ERROR(fn(dynamic_dimension.index, dynamic_dimension.dim, - dynamic_size, - parent_->constraint_mapping_[dynamic_dimension])); + TF_RETURN_IF_ERROR( + fn(dynamic_dimension.index, dynamic_dimension.dim, dynamic_size)); } } return Status::OK(); @@ -1536,10 +1386,9 @@ Status DynamicDimensionInferenceVisitor::ForEachDynamicDimensionInOperand( HloInstruction* dynamic_size = parent_->GetDynamicSize( dynamic_dimension.inst, dynamic_dimension.index, dynamic_dimension.dim); - CHECK_NE(parent_->constraint_mapping_.count(dynamic_dimension), 0); TF_RETURN_IF_ERROR(fn(dynamic_dimension.inst, dynamic_dimension.index, - dynamic_dimension.dim, operand_index, dynamic_size, - parent_->constraint_mapping_[dynamic_dimension])); + dynamic_dimension.dim, operand_index, + dynamic_size)); } } return Status::OK(); @@ -1555,6 +1404,24 @@ Status DynamicDimensionInferenceVisitor::ForEachOperandDynamicDimension( return Status::OK(); } +void DynamicDimensionInference::SetDynamicSize(HloInstruction* inst, + const ShapeIndex& index, + int64 dim, + HloInstruction* size) { + VLOG(1) << "Set dimension inst " << inst->ToString() << " index " + << index.ToString() << "@" << dim << " to " << size->ToShortString(); + Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index); + CHECK(!subshape.IsTuple()) << "Can't set a tuple shape to dynamic dimension"; + CHECK(dim < subshape.rank() && dim >= 0) + << "Asked to set invalid dynamic dimension. Shape: " + << subshape.ToString() << ", Dimension: " << dim; + DynamicDimension dynamic_dimension{inst, index, dim}; + // Updating a dynamic dimension twice overwrites the previous one. 
+ dynamic_mapping_[dynamic_dimension] = size; + auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); + iter.first->second.emplace(dynamic_dimension); +} + void DynamicDimensionInference::CopyMapping(HloInstruction* from, HloInstruction* to) { auto iter = per_hlo_dynamic_dimensions_.find(from); @@ -1564,7 +1431,7 @@ void DynamicDimensionInference::CopyMapping(HloInstruction* from, GetDynamicSize(dynamic_dimension.inst, dynamic_dimension.index, dynamic_dimension.dim); SetDynamicSize(to, dynamic_dimension.index, dynamic_dimension.dim, - dynamic_size, constraint_mapping_[dynamic_dimension]); + dynamic_size); } } } @@ -1624,8 +1491,6 @@ Status DynamicDimensionInference::ForwardDynamicSize(HloInstruction* inst, auto iter = dynamic_mapping_.find(dynamic_dimension); if (iter != dynamic_mapping_.end()) { dynamic_mapping_.insert({dynamic_dimension_new, iter->second}); - constraint_mapping_.insert( - {dynamic_dimension_new, constraint_mapping_[dynamic_dimension]}); auto iter = per_hlo_dynamic_dimensions_.try_emplace(new_inst); iter.first->second.emplace(dynamic_dimension_new); } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h index 607d68bd9c3..1597538e9ac 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.h +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.h @@ -55,8 +55,7 @@ class DynamicDimensionInference { // go into tuples. bool HasDynamicDimension(HloInstruction* inst) const; - // Forward dynamic dimension size at `dim` and its constraint from `inst` to - // `new_inst`. + // Forward dynamic dimension size at `dim` from `inst` to `new_inst`. Status ForwardDynamicSize(HloInstruction* inst, HloInstruction* new_inst, const ShapeIndex& index); @@ -64,9 +63,7 @@ class DynamicDimensionInference { // `inst` at `index` has a dynamic size, and its runtime size is represented // by a scalar instruction `size`. void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, - HloInstruction* size) { - SetDynamicSize(inst, index, dim, size, DimensionConstraint(1, 1)); - } + HloInstruction* size); // For all tensors whose dynamic dimension is `replace`, replace them with // `with`. @@ -106,116 +103,6 @@ class DynamicDimensionInference { } }; - // DimensionConstraint is attached to each dynamic dimension and describe the - // constraint of each dimension. This is used to disambiguate the index of - // dynamic dimension for reshapes that "splits" a dimension into two. - // - // As an example, consider the following reshapes: - // [<=3, 3] <- Assume first dimension is dynamic. - // | - // Reshape.1 - // | - // [<=9] <- Dimension 9 is dynamic - // | - // Reshape.2 - // | - // [3, 3] <- Ambiguous dimension after splitting 9 into [3, 3] - // - // There is no way to know which dimension is dynamic by looking at the second - // reshape locally. 
- // - // However, if we look at the dynamic dimension 9, since it comes from - // collapsing a major dynamic dimension of 3 (the dynamic size can be 0, 1, 2, - // 3, denoted as i in the diagram below) and a minor static dimension of 3, we - // know it has certain constraints that the reshape can only be one of the 4 - // forms: - // - // o: Padded Data - // x: Effective Data - // - // [<=3, 3] to [9] - // - // +---+ +---+ +---+ +---+ - // |ooo| |ooo| |ooo| |xxx| - // |ooo| |ooo| |xxx| |xxx| - // |ooo| |xxx| |xxx| |xxx| - // +---+ +---+ +---+ +---+ - // - // Reshape Reshape Reshape Reshape - // - // +-----------+ +-----------+ +-----------+ +-----------+ - // |ooo|ooo|ooo| or |xxx|ooo|ooo| or |xxx|xxx|ooo| or |xxx|xxx|xxx| stride=1 - // +-----------+ +-----------+ +-----------+ +-----------+ - // i = 0 i = 1 i = 2 i = 3 - // - // On the other hand, if the minor dimension 3 is dynamic and major dimension - // is static, we will have the following form: - // - // [3, <=3] to [9] - // - // +---+ +---+ +---+ +---+ - // |ooo| |xoo| |xxo| |xxx| - // |ooo| |xoo| |xxo| |xxx| - // |ooo| |xoo| |xxo| |xxx| - // +---+ +---+ +---+ +---+ - // - // Reshape Reshape Reshape Reshape - // - // +-----------+ +-----------+ +-----------+ +-----------+ - // |ooo|ooo|ooo| or |xoo|xoo|xoo| or |xxo|xxo|xxo| or |xxo|xxo|xxo| stride=3 - // +-----------+ +-----------+ +-----------+ +-----------+ - // i = 0 i = 1 i = 2 i = 3 - // - // By encoding constraint as a stride of elements we can recover this - // information later when we reshape from [9] to [3, 3]. We know which form - // ([3, i] or [i,3]) we should reshape the [9] into. - // - // - struct DimensionConstraint { - explicit DimensionConstraint(int64 s, int64 m) - : stride(s), multiple_of(m) {} - DimensionConstraint() : stride(1), multiple_of(1) {} - // Stride represents the distance of a newly placed element and the previous - // placed element on this dynamic dimension. - int64 stride; - - // multiple_of represents the constraints that - // - // `dynamic_size` % `multiple_of` == 0 - int64 multiple_of; - }; - - using ConstraintMapping = - absl::flat_hash_map; - - ConstraintMapping constraint_mapping_; - - // Update the dynamic mapping so that we know dimension `dim` of instruction - // `inst` at `index` has a dynamic size, and its runtime size is represented - // by a scalar instruction `size`. - void SetDynamicSize(HloInstruction* inst, const ShapeIndex& index, int64 dim, - HloInstruction* size, DimensionConstraint constraint) { - VLOG(1) << "Set dimension inst " << inst->ToString() << " index " - << index.ToString() << "@" << dim << " to " << size->ToShortString() - << " constraint: " << constraint.multiple_of; - Shape subshape = ShapeUtil::GetSubshape(inst->shape(), index); - CHECK(!subshape.IsTuple()) - << "Can't set a tuple shape to dynamic dimension"; - CHECK(dim < subshape.rank() && dim >= 0) - << "Asked to set invalid dynamic dimension. Shape: " - << subshape.ToString() << ", Dimension: " << dim; - DynamicDimension dynamic_dimension{inst, index, dim}; - // Updating a dynamic dimension twice overwrites the previous one. - dynamic_mapping_[dynamic_dimension] = size; - if (constraint_mapping_.count(dynamic_dimension) != 0) { - CHECK_EQ(constraint_mapping_[dynamic_dimension].stride, - constraint.stride); - } - constraint_mapping_[dynamic_dimension] = constraint; - auto iter = per_hlo_dynamic_dimensions_.try_emplace(inst); - iter.first->second.emplace(dynamic_dimension); - } - // Copies the internal mapping from instruction `from` to instruction `to`. 
// This is useful when an instruction is replaced by the other during the // inferencing process. diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index 44fdda0f411..c1f9da599e8 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -688,9 +688,7 @@ StatusOr RewriteDynamicConcat( dynamic_size)); } } - for (HloInstruction* user : prev_users) { - TF_RETURN_IF_ERROR(concat->ReplaceUseWith(user, rewritten_concat)); - } + TF_RETURN_IF_ERROR(concat->ReplaceUsesWith(prev_users, rewritten_concat)); TF_RETURN_IF_ERROR(dynamic_dimension_inference->ForwardDynamicSize( concat, rewritten_concat, {})); return true; diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index e4c70317f2b..e8f429d9db6 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -83,8 +83,8 @@ class DynamicPadderTest : public HloTestBase { return module; } - StatusOr RunPadder() { - DynamicPadder padder(/*slice_dynamic_output=*/true, + StatusOr RunPadder(bool slice_dynamic_output = false) { + DynamicPadder padder(/*slice_dynamic_output=*/slice_dynamic_output, CustomCallDynamicDimensionInference, OpHasDynamismSupport); return padder.Run(module_.get()); @@ -162,7 +162,7 @@ ENTRY main { module_ = GetHloModule(hlo_text); - TF_ASSERT_OK(RunPadder().status()); + TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status()); // After rewrite, we should have : // // param @@ -218,7 +218,7 @@ ENTRY main { module_ = GetHloModule(hlo_text); - TF_ASSERT_OK(RunPadder().status()); + TF_ASSERT_OK(RunPadder(/*slice_dynamic_output=*/true).status()); // After rewrite, we should have : // // param @@ -654,26 +654,16 @@ XLA_TEST_F(ExecutionTest, DynamicConcat) { const string hlo_text = R"( HloModule DynamicConcat -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - rhs = s32[] parameter(1) - ROOT add = s32[] add(lhs, rhs) -} - ENTRY main { param_0 = s32[3] parameter(0) param_1 = s32[3] parameter(1) param_2 = s32[3] parameter(2) size = s32[] constant(2) - param_padded_0 = s32[3] set-dimension-size(param_0, size), dimensions={0} - param_padded_2 = s32[3] set-dimension-size(param_2, size), dimensions={0} - %concatenate = s32[9] - concatenate(s32[3] param_padded_0, s32[3] param_1, s32[3] param_padded_2), + param_padded_0 = s32[<=3] set-dimension-size(param_0, size), dimensions={0} + param_padded_2 = s32[<=3] set-dimension-size(param_2, size), dimensions={0} + ROOT %concatenate = s32[9] + concatenate(s32[<=3] param_padded_0, s32[<=3] param_1, s32[<=3] param_padded_2), dimensions={0} - init = s32[] constant(0) - ROOT reduce = s32[] reduce(concatenate, init), - dimensions={0}, - to_apply=update_s32 } )"; @@ -686,10 +676,10 @@ ENTRY main { LiteralUtil::CreateR1({6, 7, -1}); // Dynamic operand. 
auto module = GetHloModule(hlo_text); - Literal result = - PadAndExecute(std::move(module), {&operand_0, &operand_1, &operand_2}); - - Literal expected = LiteralUtil::CreateR0(28); + Literal result = PadAndExecute(std::move(module), + {&operand_0, &operand_1, &operand_2}, false); + result.SetDynamicSize(0, 7); + Literal expected = LiteralUtil::CreateR1({1, 2, 3, 4, 5, 6, 7}); EXPECT_EQ(result, expected); } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4b6c30cadc4..98d523487b4 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -2462,10 +2462,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(hlo->operand(i))(index)); operands.push_back(operand_value); } - std::vector input_generators; - for (const HloInstruction* instr : hlo->operands()) { - input_generators.push_back(operand_to_generator.at(instr)); - } return EmitElementalMap(Cast(hlo), operands); }; case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 61ce6200a28..d5cf2ee9ac0 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -93,7 +93,8 @@ StatusOr Executable::ExecuteOnStream( static ExecutionInput MakeMaybeOwningDeviceMemoryTree( const ShapedBuffer& shaped_buffer) { - ExecutionInput result(shaped_buffer.on_device_shape()); + ExecutionInput result(shaped_buffer.on_device_shape(), + shaped_buffer.on_host_shape()); shaped_buffer.buffers().ForEachElement( [&](const ShapeIndex& index, const se::DeviceMemoryBase& mem) { result.SetBuffer(index, MaybeOwningDeviceMemory(mem)); @@ -105,10 +106,10 @@ StatusOr Executable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, HloExecutionProfile* hlo_execution_profile) { - std::vector args(arguments.size()); - auto out_it = args.begin(); + std::vector args; + args.reserve(arguments.size()); for (const ShapedBuffer* arg : arguments) { - *out_it++ = MakeMaybeOwningDeviceMemoryTree(*arg); + args.emplace_back(MakeMaybeOwningDeviceMemoryTree(*arg)); } TF_ASSIGN_OR_RETURN(ExecutionOutput out, ExecuteAsyncOnStream(run_options, std::move(args), diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 6881f6dd68a..2e3ddedfb8c 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -60,10 +60,17 @@ namespace xla { // with their indices absent from unowned_indices_. class ExecutionInput { public: - ExecutionInput() = default; - explicit ExecutionInput(xla::Shape shape) : buffers_(std::move(shape)) {} - explicit ExecutionInput(ShapeTree buffers) - : buffers_(std::move(buffers)) {} + explicit ExecutionInput(xla::Shape shape, xla::Shape host_shape) + : buffers_(std::move(shape)) { + SetHostShape(std::move(host_shape)); + } + + explicit ExecutionInput(ShapeTree buffers, + xla::Shape host_shape) + : buffers_(std::move(buffers)) { + SetHostShape(std::move(host_shape)); + } + ExecutionInput(ExecutionInput&&) = default; ~ExecutionInput(); @@ -74,6 +81,10 @@ class ExecutionInput { return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape(); } + const Shape& host_shape() const { + return host_shape_ != nullptr ? 
*host_shape_ : shape(); + } + Status SetDynamicShape(Shape dynamic_shape); xla::StatusOr ToShapedBuffer( @@ -94,6 +105,8 @@ class ExecutionInput { unowned_indices_.erase(index); } + const std::set& unowned_indices() { return unowned_indices_; } + const ShapeTree& Buffers() const { return buffers_; } ShapeTree* MutableBuffers() { return &buffers_; } @@ -107,11 +120,18 @@ class ExecutionInput { } private: + void SetHostShape(xla::Shape host_shape) { + if (shape() != host_shape) { + host_shape_ = absl::make_unique(std::move(host_shape)); + } + } + ShapeTree buffers_; // Set of indices of buffers that should be returned to the caller if an error // occurs when enqueuing the computation. std::set unowned_indices_; std::unique_ptr dynamic_shape_; + std::unique_ptr host_shape_; }; // ExecutionOutput encapsulates the output buffers of a execution and the @@ -172,6 +192,12 @@ class ExecutionOutput { return std::move(to_be_released_); } + std::vector ConsumeAliasedIndices() { + auto aliased = std::move(aliased_indices_); + aliased_indices_.clear(); + return aliased; + } + private: ScopedShapedBuffer result_; diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 1838f65e6ea..d38873a501d 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -269,6 +269,22 @@ static StatusOr PermuteBatchAndOffsetDims( return MakeTransposeHlo(accumulator, permutation); } +// Computes how many trips a loop implementing this gather op would take. +static int64 GatherLoopTripCount(HloInstruction* gather_instr) { + HloInstruction* start_indices = gather_instr->mutable_operand(1); + const Shape& start_indices_shape = start_indices->shape(); + const GatherDimensionNumbers& dim_numbers = + gather_instr->gather_dimension_numbers(); + + int64 trip_count = 1; + for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + trip_count *= start_indices_shape.dimensions(i); + } + } + return trip_count; +} + // High Level Algorithm // // We follow the following steps in sequence: @@ -311,20 +327,13 @@ StatusOr GatherExpander::ExpandInstruction( HloComputation* computation = gather_instr->parent(); HloInstruction* operand = gather_instr->mutable_operand(0); HloInstruction* start_indices = gather_instr->mutable_operand(1); - const Shape& start_indices_shape = start_indices->shape(); const Shape& output_shape = gather_instr->shape(); int64 output_rank = output_shape.dimensions_size(); const GatherDimensionNumbers& dim_numbers = gather_instr->gather_dimension_numbers(); - int64 gather_loop_trip_count = 1; - for (int64 i = 0, e = start_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - gather_loop_trip_count *= start_indices_shape.dimensions(i); - } - } - + int64 gather_loop_trip_count = GatherLoopTripCount(gather_instr); if (!IsInt32(gather_loop_trip_count)) { return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " @@ -373,7 +382,11 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { return inst->opcode() == HloOpcode::kGather && // Avoid expanding gather ops that produce zero sized tensors, // instead punt these to ZeroSizedHloElimination. 
- !ShapeUtil::IsZeroElementArray(inst->shape()); + !ShapeUtil::IsZeroElementArray(inst->shape()) && + // In kEliminateSimpleGathers mode, we only simplify instructions + // which can be represented without a loop -- i.e. we only simplify + // gathers which have a trip count of 1. + (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 5625a37cb46..e665fcd713c 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -21,10 +21,30 @@ limitations under the License. namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic -// slices. This lets backends that don't support gather directly to -// nevertheless have a minimum level of support. +// slices. +// +// This pass can be used two ways: +// +// - kEliminateAllGathers: For backends that don't support gather, this pass +// can convert every gather to a loop. +// +// - kEliminateSimpleGathers: For backends that *do* support gather, this pass +// can strength-reduce "simple" gathers -- specifically, gathers that can be +// represented without a loop -- to dynamic-slices. +// +// Note that even in kEliminateSimpleGathers mode, this pass may still expand a +// gather into a loop (with a trip-count of 1). It's up to other simplification +// passes to remove the loop. +// class GatherExpander : public OpExpanderPass { public: + enum Mode { + kEliminateAllGathers, + kEliminateSimpleGathers, + }; + + explicit GatherExpander(Mode m) : mode_(m) {} + absl::string_view name() const override { return "gather_expander"; } protected: @@ -32,6 +52,9 @@ class GatherExpander : public OpExpanderPass { StatusOr<HloInstruction*> ExpandInstruction( HloInstruction* gather_inst) override; + + private: + Mode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 706327091d9..4b0808e9aaf 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -42,7 +43,9 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - Status status = GatherExpander{}.Run(module.get()).status(); + Status status = GatherExpander{GatherExpander::kEliminateAllGathers} + .Run(module.get()) + .status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); ASSERT_THAT( @@ -68,7 +71,9 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -129,7 +134,9 @@ ENTRY main { OpMetadata metadata; metadata.set_op_name("Gather"); module->entry_computation()->root_instruction()->set_metadata(metadata); - TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); ASSERT_TRUE(changed); HloInstruction* while_instr = nullptr; @@ -147,5 +154,54 @@ ENTRY main { "after gather expansion"; EXPECT_EQ(while_instr->metadata().op_name(), "Gather"); } + +TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) { + const string hlo_text = R"( +HloModule TensorFlowGatherV1 + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + ROOT gather = s32[2,3] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={0}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1, 3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateSimpleGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_FALSE(changed); +} + +TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) { + const string hlo_text = R"( +HloModule test + +ENTRY main { + operand = s32[100] parameter(0) + indices = s32[1] parameter(1) + ROOT gather = s32[10] gather(operand, indices), + offset_dims={0}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=0, + slice_sizes={10} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateAllGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_TRUE(changed); + ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(), + {HloOpcode::kGather})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index b22f258bac6..074fbd92b27 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -27,6 +27,7 @@ load( "if_cuda_is_configured", ) load("//tensorflow:tensorflow.bzl", "if_nccl") +load("//third_party/mlir:tblgen.bzl", "gentbl") package( default_visibility = [":friends"], @@ -170,7 +171,6 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", 
- "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", "//tensorflow/compiler/xla/service/llvm_ir:buffer_assignment_util", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", @@ -286,6 +286,7 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", @@ -687,7 +688,7 @@ cc_library( ":gpu_autotuning_proto_cc", ":gpu_conv_runner", ":gpu_executable", - ":hlo_algorithm_blacklist", + ":hlo_algorithm_denylist", ":ir_emission_utils", ":stream_executor_util", "@com_google_absl//absl/algorithm:container", @@ -1168,6 +1169,8 @@ cc_library( "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:comparison_expander", + "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_4d_expander", "//tensorflow/compiler/xla/service:dot_decomposer", @@ -1176,6 +1179,7 @@ cc_library( "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", + "//tensorflow/compiler/xla/service:gather_expander", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", @@ -1660,9 +1664,9 @@ tf_proto_library_cc( ) cc_library( - name = "hlo_algorithm_blacklist", - srcs = ["hlo_algorithm_blacklist.cc"], - hdrs = ["hlo_algorithm_blacklist.h"], + name = "hlo_algorithm_denylist", + srcs = ["hlo_algorithm_denylist.cc"], + hdrs = ["hlo_algorithm_denylist.h"], deps = [ ":gpu_autotuning_proto_cc", "//tensorflow/compiler/xla:debug_options_flags", @@ -1673,12 +1677,12 @@ cc_library( ) tf_cc_test( - name = "hlo_algorithm_blacklist_test", - srcs = ["hlo_algorithm_blacklist_test.cc"], - data = ["data/hlo_algorithm_blacklist.pbtxt"], + name = "hlo_algorithm_denylist_test", + srcs = ["hlo_algorithm_denylist_test.cc"], + data = ["data/hlo_algorithm_denylist.pbtxt"], tags = ["no_pip"], deps = [ - ":hlo_algorithm_blacklist", + ":hlo_algorithm_denylist", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -1875,3 +1879,49 @@ cc_library( "@com_google_absl//absl/types:span", ], ) + +gentbl( + name = "xla_thunks_ops_inc_gen", + tbl_outs = [ + ("-gen-op-decls", "ir/xla_thunks_ops.h.inc"), + ("-gen-op-defs", "ir/xla_thunks_ops.cc.inc"), + ("-gen-struct-attr-decls", "ir/xla_thunks_structs.h.inc"), + ("-gen-struct-attr-defs", "ir/xla_thunks_structs.cc.inc"), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "ir/xla_thunks_ops.td", + td_srcs = [ + "@llvm-project//mlir:LLVMOpsTdFiles", + ], +) + +cc_library( + name = "xla_thunks_ops", + srcs = [ + "ir/xla_thunks_ops.cc", + "ir/xla_thunks_ops.cc.inc", + "ir/xla_thunks_ops.h.inc", + ], + hdrs = [ + "ir/xla_thunks_ops.h", + ], + deps = [ + ":xla_thunks_ops_inc_gen", + "//tensorflow/compiler/mlir/hlo", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + ], +) + +# Library with XLA thunks dialect static initialization. 
+cc_library( + name = "xla_thunks_dialect_registration", + srcs = [ + "ir/dialect_registration.cc", + ], + deps = [ + ":xla_thunks_ops", + "@llvm-project//mlir:IR", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc index bb76bf02eba..b3b5cf7e048 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc @@ -220,10 +220,13 @@ RefcountingHashMap& GlobalRendezvousMap() { CollectivePermuteThunk::CollectivePermuteThunk( ThunkInfo thunk_info, const BufferAllocation::Slice& src, const BufferAllocation::Slice& dest) - : Thunk(kCollectivePermute, thunk_info), src_(src), dest_(dest) {} + : Thunk(kCollectivePermute, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), + src_(src), + dest_(dest) {} Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) { - auto* instr = Cast(hlo_instruction()); + auto* instr = Cast(hlo_instruction_); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h index 329db00c66a..44cc6a1c64e 100644 --- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h @@ -33,6 +33,7 @@ class CollectivePermuteThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; BufferAllocation::Slice src_; BufferAllocation::Slice dest_; }; diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 041aa9b6fa3..4cff48a89da 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -29,6 +29,7 @@ ConditionalThunk::ConditionalThunk( absl::Span branch_operand_buffer_indexes, std::vector branch_thunk_sequences) : Thunk(Kind::kConditional, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), branch_index_is_bool_( thunk_info.hlo_instruction->operand(0)->shape().element_type() == PRED), @@ -45,13 +46,6 @@ ConditionalThunk::ConditionalThunk( } } -void ConditionalThunk::ComputeAnnotations() { - Thunk::ComputeAnnotations(); - for (auto& branch_thunk : branch_thunks_) { - branch_thunk->ComputeAnnotations(); - } -} - Status ConditionalThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { if (branch_index_is_bool_) { @@ -91,8 +85,8 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { branch_index = pred ? 0 : 1; } else { // Handle default scenario for branch_index not in [0, num_branches). 
- if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { - branch_index = hlo_instruction()->branch_count() - 1; + if (branch_index < 0 || branch_index >= hlo_instruction_->branch_count()) { + branch_index = hlo_instruction_->branch_count() - 1; } } @@ -100,7 +94,7 @@ Status ConditionalThunk::ExecuteOnStream(const ExecuteParams& params) { profiler.StartHloComputation(); TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(params)); profiler.FinishHloComputation( - hlo_instruction()->branch_computation(branch_index)); + hlo_instruction_->branch_computation(branch_index)); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h index a00285efa7c..f91f1c52146 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -51,12 +51,12 @@ class ConditionalThunk : public Thunk { ConditionalThunk(const ConditionalThunk&) = delete; ConditionalThunk& operator=(const ConditionalThunk&) = delete; - void ComputeAnnotations() override; Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const bool branch_index_is_bool_; BufferAllocation::Slice branch_index_buffer_index_; std::vector branch_operand_buffer_indexes_; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index df3dd6d4593..3048db95c39 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -35,12 +35,11 @@ ConvolutionThunk::ConvolutionThunk( BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, BufferAllocation::Slice tuple_result_slice) : Thunk(Kind::kConvolution, thunk_info), + cudnn_call_(Cast(thunk_info.hlo_instruction)), operand_buffers_(std::move(operand_slices)), result_buffer_(result_slice), scratch_buffer_(scratch_slice), - tuple_result_buffer_(tuple_result_slice) { - cudnn_call_ = Cast(hlo_instruction()); -} + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream(const ExecuteParams& params) { const auto& buffer_allocations = *params.buffer_allocations; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 36f415d9d89..e91b2c4d0d2 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -98,6 +98,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, const BufferAllocation::Slice& output) : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), operand_(operand), scale_(scale), offset_(offset), @@ -106,7 +107,7 @@ CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( epsilon_(epsilon), feature_index_(feature_index), output_(output) { - const auto* hlo = hlo_instruction(); + const auto* hlo = hlo_instruction_; CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardInferenceCallTarget); @@ -130,7 +131,7 @@ Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( 
buffer_allocations.GetDeviceAddress(variance_)); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardInference( - hlo_instruction(), operand, output_base, scale, offset, mean, variance, + hlo_instruction_, operand, output_base, scale, offset, mean, variance, epsilon_, feature_index_, &stream)); if (!stream.ok()) { @@ -148,6 +149,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( const BufferAllocation::Slice& output_inv_stddev, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), operand_(operand), scale_(scale), offset_(offset), @@ -157,7 +159,7 @@ CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( output_mean_(output_mean), output_inv_stddev_(output_inv_stddev), output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction(); + const auto* hlo = hlo_instruction_; CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); @@ -183,7 +185,7 @@ Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); auto& stream = *params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormForwardTraining( - hlo_instruction(), operand, output_data, output_mean, output_inv_stddev, + hlo_instruction_, operand, output_data, output_mean, output_inv_stddev, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), epsilon_, feature_index_, &stream)); @@ -214,6 +216,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( const BufferAllocation::Slice& output_grad_offset, const BufferAllocation::Slice& output_tuple) : Thunk(Thunk::Kind::kCudnnBatchNormBackward, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), operand_(operand), scale_(scale), mean_(mean), @@ -225,7 +228,7 @@ CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( output_grad_scale_(output_grad_scale), output_grad_offset_(output_grad_offset), output_tuple_(output_tuple) { - const auto* hlo = hlo_instruction(); + const auto* hlo = hlo_instruction_; CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); @@ -253,7 +256,7 @@ Status CudnnBatchNormBackwardThunk::ExecuteOnStream( params.profiler->MakeScopedInstructionProfiler(profile_index()); se::Stream* stream = params.stream; TF_RETURN_IF_ERROR(RunCudnnBatchNormBackward( - hlo_instruction(), operand, output_grad_data, grad_output, + hlo_instruction_, operand, output_grad_data, grad_output, output_grad_scale, output_grad_offset, se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h index 5897435a58f..bb46017b8fb 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -63,6 +63,7 @@ class CudnnBatchNormForwardInferenceThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; @@ -92,6 +93,7 @@ 
class CudnnBatchNormForwardTrainingThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice offset_; @@ -124,6 +126,7 @@ class CudnnBatchNormBackwardThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; BufferAllocation::Slice operand_; BufferAllocation::Slice scale_; BufferAllocation::Slice mean_; diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc index 16a1f923c91..dae15659402 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.cc @@ -26,11 +26,12 @@ CustomCallThunk::CustomCallThunk( std::vector> operand_slices, ShapeTree result_slices, std::string opaque) : Thunk(Thunk::kCustomCall, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), call_target_(call_target), operand_slices_(std::move(operand_slices)), result_slices_(std::move(result_slices)), opaque_(std::move(opaque)) { - const HloInstruction* instr = hlo_instruction(); + const HloInstruction* instr = hlo_instruction_; CHECK_EQ(instr->operand_count(), operand_slices_.size()); for (int64 i = 0; i < instr->operand_count(); ++i) { const auto& s1 = operand_slices_[i].shape(); diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h index 72175daf3dd..31c03f5252f 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/custom_call_thunk.h @@ -46,6 +46,7 @@ class CustomCallThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; void* call_target_; std::vector> operand_slices_; ShapeTree result_slices_; diff --git a/tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_blacklist.pbtxt b/tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_denylist.pbtxt similarity index 100% rename from tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_blacklist.pbtxt rename to tensorflow/compiler/xla/service/gpu/data/hlo_algorithm_denylist.pbtxt diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index eee0fc83481..3f000a2491d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -289,7 +289,7 @@ StatusOr GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, auto one_with_sign = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::copysign, {one, input}, {type}, b_); return FPCast(Select(FCmpULT(abs_value, max_value), fast_tanh, one_with_sign), - value->getType()); + value->getType(), "tanh"); } StatusOr GpuElementalIrEmitter::EmitComplexAbs( diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 7fc3bdd4436..ccd661d8ade 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -26,6 +26,7 @@ namespace gpu { ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, std::unique_ptr body_thunk_sequence) : Thunk(Kind::kWhile, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), loop_limit_(loop_limit), body_thunk_sequence_(absl::make_unique( 
// Pass nullptr as the HloInstruction* to the body_thunk_sequence_ @@ -33,11 +34,6 @@ ForThunk::ForThunk(ThunkInfo thunk_info, const int64 loop_limit, // this ForThunk, and shouldn't be profiled separately from it. ThunkInfo(), std::move(*body_thunk_sequence))) {} -void ForThunk::ComputeAnnotations() { - Thunk::ComputeAnnotations(); - body_thunk_sequence_->ComputeAnnotations(); -} - Status ForThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable, executor)); @@ -46,14 +42,14 @@ Status ForThunk::Initialize(const GpuExecutable& executable, Status ForThunk::ExecuteOnStream(const ExecuteParams& params) { VLOG(2) << "Executing ForThunk with " << loop_limit_ << " iters for " - << (hlo_instruction() ? hlo_instruction()->ToString() : ""); + << (hlo_instruction_ ? hlo_instruction_->ToString() : ""); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); for (int64 i = 0; i < loop_limit_; ++i) { params.profiler->StartHloComputation(); // Invoke loop body thunk sequence. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - params.profiler->FinishHloComputation(hlo_instruction()->while_body()); + params.profiler->FinishHloComputation(hlo_instruction_->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 77a89ea6023..b6ee950737e 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -36,12 +36,12 @@ class ForThunk : public Thunk { ForThunk(const ForThunk&) = delete; ForThunk& operator=(const ForThunk&) = delete; - void ComputeAnnotations() override; Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const int64 loop_limit_; std::unique_ptr body_thunk_sequence_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 561dfbe3137..e55df0bb230 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -40,6 +40,7 @@ GemmThunk::GemmThunk(ThunkInfo thunk_info, bool implements_whole_instruction, const GemmBackendConfig &backend_config) : Thunk(Kind::kGemm, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), lhs_buffer_(lhs_buffer), rhs_buffer_(rhs_buffer), output_buffer_(output_buffer), @@ -51,11 +52,11 @@ Status GemmThunk::ExecuteOnStream(const ExecuteParams ¶ms) { return params.buffer_allocations->GetDeviceAddress(slice); }; - VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction(); + VLOG(3) << "Running GEMM thunk on instruction: " << hlo_instruction_; se::DeviceMemoryBase lhs_data = get_device_address(lhs_buffer_); se::DeviceMemoryBase rhs_data = get_device_address(rhs_buffer_); se::DeviceMemoryBase output_data = get_device_address(output_buffer_); - return RunGemm(hlo_instruction(), backend_config_, lhs_data, rhs_data, + return RunGemm(hlo_instruction_, backend_config_, lhs_data, rhs_data, output_data, params.stream, implements_whole_instruction_, profile_index(), params.profiler); } @@ -82,24 +83,28 @@ static bool DoGemmWithAlgorithm( // Converts from an XLA PrimitiveType to a blas::ComputationType, which is // used to specify the precision with which matmul computations should be // performed, 
separately from the precision of the inputs and result. - se::blas::ComputationType computation_type = [&](PrimitiveType type) { - switch (type) { - case F16: - // Use F32 as computation type for F16 as we currently only implement - // the cuDNN pseudo half configuration for half precision. - return se::blas::ComputationType::kF32; - case F32: - return se::blas::ComputationType::kF32; - case F64: - return se::blas::ComputationType::kF64; - case C64: - return se::blas::ComputationType::kComplexF32; - case C128: - return se::blas::ComputationType::kComplexF64; - default: - LOG(FATAL) << "Unsupported type."; - } - }(type); + se::blas::ComputationType computation_type; + switch (type) { + case F16: + // Use F32 as computation type for F16 as we currently only implement + // the cuDNN pseudo half configuration for half precision. + computation_type = se::blas::ComputationType::kF32; + break; + case F32: + computation_type = se::blas::ComputationType::kF32; + break; + case F64: + computation_type = se::blas::ComputationType::kF64; + break; + case C64: + computation_type = se::blas::ComputationType::kComplexF32; + break; + case C128: + computation_type = se::blas::ComputationType::kComplexF64; + break; + default: + return false; + } se::DeviceMemory lhs_data(lhs_matrix.data); se::DeviceMemory rhs_data(rhs_matrix.data); @@ -296,7 +301,7 @@ Status RunGemm(const HloInstruction *gemm, stream, best_algorithm, /*output_profile_result=*/profile_result); default: - LOG(FATAL) << "Unsupported type."; + return false; } }(); diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 2bccb7b3572..1a51a7d4e0c 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -51,6 +51,7 @@ class GemmThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; const BufferAllocation::Slice output_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto index 35b5cfacb2d..563245da969 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto +++ b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto @@ -15,19 +15,19 @@ message ConvInstructionLog { repeated uint64 operand_addresses = 4; } -message BlacklistedAlgorithm { +message DenylistedAlgorithm { int64 id = 1; bool tensor_ops = 2; } -message AlgorithmBlacklistEntry { +message AlgorithmDenylistEntry { string hlo = 1; tensorflow.ComputeCapability cc = 2; tensorflow.CudnnVersion cudnn_version = 3; string blas_version = 5; - repeated BlacklistedAlgorithm algos = 4; + repeated DenylistedAlgorithm algos = 4; } -message AlgorithmBlacklist { - repeated AlgorithmBlacklistEntry entries = 1; +message AlgorithmDenylist { + repeated AlgorithmDenylistEntry entries = 1; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 3dcdb4c90eb..f5bf7476059 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -35,6 +35,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" +#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_4d_expander.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" @@ -42,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" #include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" +#include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" @@ -138,6 +141,9 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); + // Comparison total order expander + pipeline.AddPass(); + // Remove zero-sized HLO from the input so that other passes don't have to // handle it. pipeline.AddPass(); @@ -179,7 +185,7 @@ Status GpuCompiler::OptimizeHloModule( pipeline.AddPass( /*expansion_type=*/LogisticExpansionType::kExp); - + pipeline.AddPass(); pipeline.AddPass(); { @@ -189,11 +195,13 @@ Status GpuCompiler::OptimizeHloModule( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pipeline.AddPass(); + pass.AddPass(); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. - pipeline.AddPass(); + pass.AddPass(); + + pass.AddPass(GatherExpander::kEliminateSimpleGathers); AlgebraicSimplifierOptions options; // When transposes appear in a fusion node, we can easily adjust the @@ -537,10 +545,10 @@ static Status CompileModuleToLlvmIrImpl( // computation. // * For each visit of these HloInstructions, either none or one Thunk // will be returned. - // * If there is a thunk returned, thunk->hlo_instruction() equals the + // * If there is a thunk returned, thunk->hlo_instruction_ equals the // input HloInstruction*. // * A returned thunk may contain other sub-thunks. A sub-thunk may or may - // not have an associated hlo_instruction(). + // not have an associated hlo_instruction_. TF_RET_CHECK(thunks->size() <= 1) << instruction->ToString(); if (!thunks->empty()) { auto thunk = std::move(thunks->front()); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 67255f02665..8fb741323f3 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" -#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -438,10 +438,9 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( (void)blas->GetVersion(&blas_version); } - absl::Span blacklisted_algos = - GetBlacklistedConvAlgorithms(GetComputeCapability(stream_exec_), - GetCudnnVersion(stream_exec_), blas_version, - canonical_hlo); + absl::Span disabled_algos = GetDisabledConvAlgorithms( + GetComputeCapability(stream_exec_), GetCudnnVersion(stream_exec_), + blas_version, canonical_hlo); for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { XLA_SCOPED_LOGGING_TIMER_LEVEL( @@ -449,7 +448,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( AlgorithmToString(alg)), 2); - if (absl::c_linear_search(blacklisted_algos, alg)) { + if (absl::c_linear_search(disabled_algos, alg)) { LOG(INFO) << "Omitted potentially buggy algorithm " << AlgorithmToString(alg) << " for conv " << instr->ToString(); continue; @@ -503,7 +502,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( if (!input_output_allocator_redzone_clear || !scratch_allocator_redzone_clear) { - AlgorithmBlacklist proto; + AlgorithmDenylist proto; auto entry = proto.add_entries(); entry->set_hlo(canonical_hlo); *entry->mutable_cc() = GetComputeCapability(stream_exec_); @@ -513,13 +512,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( algo->set_id(alg.algo_id()); algo->set_tensor_ops(alg.tensor_ops_enabled()); - LOG(ERROR) - << "To blacklist this algorithm for this convolution, " - "copy-paste the following " - "proto to the blacklist file pointed by XLA_FLAGS " - "--xla_gpu_algorithm_blacklist_path=" - << GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path() - << " : " << proto.ShortDebugString(); + LOG(ERROR) << "To denylist this algorithm for this convolution, " + "copy-paste the following " + "proto to the denylist file pointed by XLA_FLAGS " + "--xla_gpu_algorithm_denylist_path=" + << GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path() + << " : " << proto.ShortDebugString(); continue; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc index a6fc4686143..5cc5fa7d16d 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc @@ -484,11 +484,12 @@ Status RunGpuConv(const HloCustomCallInstruction* conv, return RunGpuConvImpl(params, scratch_allocator, stream, options); default: - LOG(FATAL) << conv->ToString(); + return Unimplemented("Unimplemented convolution %s", + conv->ToString()); } } default: - LOG(FATAL) << conv->ToString(); + return Unimplemented("Unimplemented convolution %s", conv->ToString()); } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 89c5e123a48..726f1963545 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -71,7 +71,6 @@ GpuExecutable::GpuExecutable( CHECK(has_module() && 
assignment_); GpuDebugInfoManager::Get()->RegisterModule(module().name(), shared_module(), assignment_); - ComputeThunkAnnotations(); } GpuExecutable::~GpuExecutable() { @@ -93,12 +92,6 @@ GpuExecutable::~GpuExecutable() { } } -void GpuExecutable::ComputeThunkAnnotations() { - for (Thunk* thunk : thunk_schedule_->TotalOrder()) { - thunk->ComputeAnnotations(); - } -} - Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( const ServiceExecutableRunOptions* run_options) { se::Stream* main_stream = run_options->stream(); @@ -186,8 +179,8 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } - VLOG(2) << "Executing the thunk for " << thunk->name() << " on stream " - << stream_no; + VLOG(2) << "Executing the thunk for " << thunk->profile_annotation() + << " on stream " << stream_no; const GpuExecutableRunOptions* gpu_options = run_options->run_options().gpu_executable_run_options(); Thunk::ExecuteParams thunk_params{ @@ -487,6 +480,12 @@ StatusOr GpuExecutable::ExecuteAsyncOnStream( ExecutionInput& input = arguments[alias->parameter_number]; MaybeOwningDeviceMemory* maybe_owning_memory = input.MutableBuffer(alias->parameter_index); + if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) { + return InvalidArgument( + "An input was configured to be must-alias at " + "compile time but not donated at runtime: %s", + alias->ToString()); + } if (absl::optional owning = maybe_owning_memory->Release()) { // If the caller passes the ownership of the device memory, reuse it diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 0da446c9739..516fa9b269a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -115,9 +115,6 @@ class GpuExecutable : public Executable { StatusOr ResolveConstantGlobals( stream_executor::Stream* stream); - // Computes annotations for each thunk and store them in thunk_annotations_. - void ComputeThunkAnnotations(); - // GpuExecutable check with either AMD's ISA version, or Nvidia's major minor // version for compute capability, depending on the hardware. Status CheckCompatibilityWithServiceExecutableRunOptions( diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.cc similarity index 81% rename from tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc rename to tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.cc index 601c805ce16..4a0075f2870 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h" #include @@ -24,7 +24,7 @@ limitations under the License. 
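The gpu_executable.cc hunk above adds a runtime check: if an input was marked must-alias at compile time but the caller did not donate ownership of its buffer, execution fails with InvalidArgument instead of silently aliasing borrowed memory. A simplified sketch of that check follows, with `MaybeOwningBuffer` standing in for xla::MaybeOwningDeviceMemory and a plain string standing in for Status.

```cpp
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct MaybeOwningBuffer {
  // Owned storage if the caller donated the buffer, std::nullopt otherwise.
  std::optional<std::vector<char>> owned;
  bool HasOwnership() const { return owned.has_value(); }
};

struct InputAlias {
  bool must_alias;
  std::string description;
};

// Returns an error message on failure, empty string on success (stand-in for
// returning InvalidArgument vs. Status::OK).
std::string CheckDonation(const InputAlias& alias, MaybeOwningBuffer& input) {
  if (alias.must_alias && !input.HasOwnership()) {
    return "input configured as must-alias at compile time but not donated "
           "at runtime: " + alias.description;
  }
  if (input.HasOwnership()) {
    std::vector<char> taken = std::move(*input.owned);
    input.owned.reset();  // stand-in for MaybeOwningDeviceMemory::Release()
    (void)taken;  // the runtime may now reuse the donated memory for outputs
  }
  return "";
}

int main() {
  MaybeOwningBuffer borrowed;                        // not donated
  MaybeOwningBuffer donated{std::vector<char>(16)};  // donated

  std::cout << CheckDonation({true, "param 0"}, borrowed) << "\n";  // error
  std::cout << (CheckDonation({true, "param 1"}, donated).empty() ? "ok"
                                                                  : "error")
            << "\n";  // ok
}
```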
namespace xla { namespace gpu { -constexpr char kDefaultBlacklist[] = R"pb( +constexpr char kDefaultDenylist[] = R"pb( entries { hlo: "(f32[4,32,32,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[4,32,32,32]{2,1,3,0}, f32[5,5,32,32]{1,0,2,3}), window={size=5x5 pad=2_2x2_2}, dim_labels=b01f_01io->b01f, custom_call_target=\"__cudnn$convForward\", backend_config=\"{conv_result_scale:1}\"" cc { major: 7 } @@ -41,28 +41,26 @@ constexpr char kDefaultBlacklist[] = R"pb( } )pb"; -absl::Span -GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, - tensorflow::CudnnVersion cudnn_version, - const std::string& blas_version, - const std::string& hlo) { +absl::Span GetDisabledConvAlgorithms( + tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version, + const std::string& blas_version, const std::string& hlo) { // Key is the tuple of canonicalized hlo, compute capability major/minor, // cudnn version major/minor/patch, blas version. using MapType = absl::flat_hash_map< std::tuple, std::vector>; - static MapType* blacklist = [] { + static MapType* denylist = [] { MapType* list = new MapType(); - AlgorithmBlacklist proto; + AlgorithmDenylist proto; std::string file_path = - GetDebugOptionsFromFlags().xla_gpu_algorithm_blacklist_path(); + GetDebugOptionsFromFlags().xla_gpu_algorithm_denylist_path(); if (!file_path.empty()) { TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_path, &proto)); } else { CHECK(tensorflow::protobuf::TextFormat::ParseFromString( - std::string(kDefaultBlacklist), &proto)); + std::string(kDefaultDenylist), &proto)); } for (const auto& entry : proto.entries()) { for (const auto& algo : entry.algos()) { @@ -77,10 +75,10 @@ GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, return list; }(); - auto iter = blacklist->find(std::make_tuple( + auto iter = denylist->find(std::make_tuple( hlo, cc.major(), cc.minor(), cudnn_version.major(), cudnn_version.minor(), cudnn_version.patch(), std::string(blas_version))); - if (iter != blacklist->end()) { + if (iter != denylist->end()) { return iter->second; } return {}; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h similarity index 62% rename from tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h rename to tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h index c1955a452aa..73d1219c1ab 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_ #include @@ -24,13 +24,11 @@ limitations under the License. 
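The renamed hlo_algorithm_denylist.cc above keeps the same lookup scheme: a map built once from either the --xla_gpu_algorithm_denylist_path textproto or the embedded default, keyed by the canonical HLO plus compute-capability, cuDNN and BLAS versions. Here is a self-contained sketch of that keyed lookup; `AlgorithmKey` and `DeniedAlgo` are simplified stand-ins for the proto-backed types, and the hard-coded entry is illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

struct DeniedAlgo {
  int64_t id;
  bool tensor_ops;
};

// (hlo, cc_major, cc_minor, cudnn_major, cudnn_minor, cudnn_patch, blas)
using AlgorithmKey =
    std::tuple<std::string, int, int, int, int, int, std::string>;

const std::vector<DeniedAlgo>& GetDisabledAlgos(const AlgorithmKey& key) {
  // Built once on first use; the real code parses either a user-supplied
  // textproto (--xla_gpu_algorithm_denylist_path) or the embedded default.
  static const std::map<AlgorithmKey, std::vector<DeniedAlgo>>* denylist = [] {
    auto* m = new std::map<AlgorithmKey, std::vector<DeniedAlgo>>();
    (*m)[{"conv_forward_example", 7, 0, 7, 6, 2, "9000"}] = {{0, false},
                                                             {7, true}};
    return m;
  }();
  static const std::vector<DeniedAlgo> empty;
  auto it = denylist->find(key);
  return it != denylist->end() ? it->second : empty;
}

int main() {
  AlgorithmKey key{"conv_forward_example", 7, 0, 7, 6, 2, "9000"};
  for (const DeniedAlgo& algo : GetDisabledAlgos(key)) {
    std::cout << "skip algorithm " << algo.id
              << (algo.tensor_ops ? " (tensor ops)" : "") << "\n";
  }
  std::cout << GetDisabledAlgos({"other hlo", 7, 0, 7, 6, 2, "9000"}).size()
            << " entries for an unknown key\n";  // 0
}
```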
namespace xla { namespace gpu { -absl::Span -GetBlacklistedConvAlgorithms(tensorflow::ComputeCapability cc, - tensorflow::CudnnVersion cudnn_version, - const std::string& blas_version, - const std::string& hlo); +absl::Span GetDisabledConvAlgorithms( + tensorflow::ComputeCapability cc, tensorflow::CudnnVersion cudnn_version, + const std::string& blas_version, const std::string& hlo); } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_BLACKLIST_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_ALGORITHM_DENYLIST_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist_test.cc similarity index 84% rename from tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc rename to tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist_test.cc index bc24f486668..ab1cc1c79de 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_blacklist.h" +#include "tensorflow/compiler/xla/service/gpu/hlo_algorithm_denylist.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" @@ -26,22 +26,22 @@ namespace xla { namespace gpu { namespace { -class BlacklistTest : public testing::Test { +class DenylistTest : public testing::Test { protected: - BlacklistTest() { + DenylistTest() { tensorflow::setenv( "XLA_FLAGS", absl::StrCat( - "--xla_gpu_algorithm_blacklist_path=", + "--xla_gpu_algorithm_denylist_path=", tensorflow::GetDataDependencyFilepath(tensorflow::io::JoinPath( "tensorflow", "compiler", "xla", "service", "gpu", "data", - "hlo_algorithm_blacklist.pbtxt"))) + "hlo_algorithm_denylist.pbtxt"))) .data(), 0); } }; -TEST_F(BlacklistTest, DefaultTest) { +TEST_F(DenylistTest, DefaultTest) { tensorflow::ComputeCapability cc; cc.set_major(7); cc.set_minor(0); @@ -49,7 +49,7 @@ TEST_F(BlacklistTest, DefaultTest) { cudnn_version.set_major(7); cudnn_version.set_minor(6); cudnn_version.set_patch(2); - auto list = GetBlacklistedConvAlgorithms( + auto list = GetDisabledConvAlgorithms( cc, cudnn_version, /*blas_version=*/"9000", R"((f16[256,112,112,64]{3,2,1,0}, u8[0]{0}) custom-call(f16[256,224,224,4]{3,2,1,0}, f16[7,7,4,64]{2,1,0,3}), window={size=7x7 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", backend_config="{conv_result_scale:1}")"); ASSERT_EQ(4, list.size()); @@ -59,7 +59,7 @@ TEST_F(BlacklistTest, DefaultTest) { EXPECT_EQ(stream_executor::dnn::AlgorithmDesc(1, true), list[3]); } -TEST_F(BlacklistTest, NegativeTest) { +TEST_F(DenylistTest, NegativeTest) { tensorflow::ComputeCapability cc; cc.set_major(7); cc.set_minor(0); @@ -68,7 +68,7 @@ TEST_F(BlacklistTest, NegativeTest) { cudnn_version.set_minor(6); cudnn_version.set_minor(2); auto list = - GetBlacklistedConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)"); + GetDisabledConvAlgorithms(cc, cudnn_version, "9000", R"(invalid hlo)"); ASSERT_EQ(0, list.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 23b29df6ec8..5d38d1b727c 100644 --- 
a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -38,13 +38,15 @@ using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( absl::Span io_hlos, absl::Span non_io_hlos) { + CHECK(is_nested_); + // I/O HLOs are bound to the arguments of the current IR function, // *excluding* the output argument, which is added to non-I/O HLOs. // I.e., // - // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg, temp_buffer_base) { + // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg); llvm::Function* function = b_->GetInsertBlock()->getParent(); - CHECK_EQ(io_hlos.size() + 2, function->arg_size()); + CHECK_EQ(io_hlos.size() + 1, function->arg_size()); // An HLO can have duplicated operands. This data structure remembers which // operand HLOs are already bound to avoid rebinding the same HLO. @@ -55,11 +57,7 @@ void HloToIrBindings::EmitBasePointersForHlos( !absl::c_count(non_io_hlos, io_hlo)) << "IO HLOs and non-IO HLOs should be disjoint"; if (!already_bound_for_this_function.contains(io_hlo)) { - if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) { - BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter)); - } else { - BindHloToIrValue(*io_hlo, &*arg_iter); - } + BindHloToIrValue(*io_hlo, &*arg_iter); already_bound_for_this_function.insert(io_hlo); } ++arg_iter; @@ -69,9 +67,6 @@ void HloToIrBindings::EmitBasePointersForHlos( arg_iter->setName("output_arg"); ++arg_iter; - temp_buffer_base_ = &*arg_iter; - temp_buffer_base_->setName("temp_buffer"); - for (const HloInstruction* non_io_hlo : non_io_hlos) { if (already_bound_for_this_function.contains(non_io_hlo)) { continue; @@ -79,62 +74,23 @@ void HloToIrBindings::EmitBasePointersForHlos( already_bound_for_this_function.insert(non_io_hlo); if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) { - if (!is_nested_) { - // Lookup allocation GetTupleElement operand. - const BufferAllocation::Slice slice = - buffer_assignment_ - ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor()) - .ConsumeValueOrDie(); - // We are not in a nested context, so check non-thread-local allocation. - CHECK(!slice.allocation()->is_thread_local()); - const int64 offset = slice.offset(); - CHECK_NE(nullptr, temp_buffer_base_); - // Emit IR for GetTupleElement instruction and bind to emitted value. - llvm::Value* base_ptr = - b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)); - BindHloToIrValue(*non_io_hlo, - EmitGetTupleElement(non_io_hlo, base_ptr)); - } - continue; - } - - if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) { continue; } ShapeUtil::ForEachSubshape( non_io_hlo->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { - // A non-IO HLO with a buffer is bound to - // (1) an alloca if it is thread-local, or - // (2) an internal pointer in temp_buffer_base according to its - // offset. 
- auto slice_result = - buffer_assignment_->GetUniqueSlice(non_io_hlo, index); - if (!slice_result.ok()) { - return; - } - const BufferAllocation::Slice slice = - slice_result.ConsumeValueOrDie(); - if (slice.allocation()->is_thread_local()) { + if (non_io_hlo->opcode() == HloOpcode::kConstant) { + llvm::Value* global_for_constant = module_->getGlobalVariable( + llvm_ir::ConstantHloToGlobalName(*non_io_hlo)); + BindHloToIrValue(*non_io_hlo, global_for_constant); + } else { llvm::Type* pointee_type = llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, llvm_ir::EmitAllocaAtFunctionEntry( pointee_type, /*name=*/"", b_), index); - } else if (slice.allocation()->is_constant()) { - llvm::Value* global_for_constant = module_->getGlobalVariable( - llvm_ir::ConstantBufferAllocationToGlobalName( - *slice.allocation())); - BindHloToIrValue(*non_io_hlo, global_for_constant); - } else { - const int64 offset = slice.offset(); - CHECK_NE(nullptr, temp_buffer_base_); - BindHloToIrValue( - *non_io_hlo, - b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)), - index); } }); } @@ -231,14 +187,14 @@ llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, << " of " << hlo.ToString(); llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); - alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index); // The GPU backend emits one kernel per top-level HLO, and LLVM views // execution of one kernel as the "whole program" executed on the GPU. // Therefore if hlo's output buffer is not modified within consumer, and if // consumer runs hlo only once (so that it doesn't create two different // outputs), then we can mark ir_array as invariant over the whole program. - if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { + if (!is_nested_ && + BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { VLOG(2) << "Marking " << hlo.name() << " as invariant within " << consumer.name(); ir_array.MarkInvariantOverWholeProgram(&module_->getContext()); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index f57b594e9c1..5eef6727801 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" namespace xla { @@ -42,8 +41,7 @@ class HloToIrBindings { : buffer_assignment_(buffer_assignment), is_nested_(is_nested), b_(b), - module_(llvm_module), - alias_analysis_(module, *buffer_assignment_, &b_->getContext()) {} + module_(llvm_module) {} void EmitBasePointersForHlos( absl::Span io_hlos, @@ -116,8 +114,6 @@ class HloToIrBindings { // The address of the memory block that contains all temporary buffers. 
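The hlo_to_ir_bindings.cc rewrite above now applies only to nested computations: a constant HLO is bound to the module-level global holding its literal, any other non-I/O HLO gets an alloca at the function entry, and the old temp-buffer GEP path is gone. The following is a standalone LLVM sketch of that binding rule, using placeholder names rather than llvm_ir::ConstantHloToGlobalName or EmitAllocaAtFunctionEntry.

```cpp
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext ctx;
  llvm::Module module("binding_sketch", ctx);
  llvm::IRBuilder<> b(ctx);
  llvm::Type* f32 = b.getFloatTy();

  // Stand-in for a constant HLO already emitted as a module-level global.
  new llvm::GlobalVariable(module, f32, /*isConstant=*/true,
                           llvm::GlobalValue::InternalLinkage,
                           llvm::ConstantFP::get(f32, 1.0), "constant_hlo");

  // Nested-computation-style function: parameters plus one output pointer,
  // and (after this change) no trailing temp-buffer argument.
  llvm::FunctionType* fn_ty = llvm::FunctionType::get(
      b.getVoidTy(), {b.getInt8PtrTy(), b.getInt8PtrTy()}, /*isVarArg=*/false);
  llvm::Function* fn = llvm::Function::Create(
      fn_ty, llvm::Function::InternalLinkage, "nested_computation", &module);
  b.SetInsertPoint(llvm::BasicBlock::Create(ctx, "entry", fn));

  // Binding rule from the hunk above: constants resolve to their global,
  // every other non-I/O HLO gets a function-entry alloca.
  llvm::Value* constant_binding =
      module.getGlobalVariable("constant_hlo", /*AllowInternal=*/true);
  llvm::Value* other_binding = b.CreateAlloca(f32, nullptr, "non_io_hlo");
  b.CreateRetVoid();

  (void)constant_binding;
  (void)other_binding;
  module.print(llvm::outs(), /*AAW=*/nullptr);
}
```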
llvm::Value* temp_buffer_base_ = nullptr; - - llvm_ir::AliasAnalysis alias_analysis_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 43cc5f5a2ae..5fe459a70bc 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -25,13 +25,15 @@ namespace gpu { InfeedThunk::InfeedThunk( ThunkInfo thunk_info, const ShapeTree& infeed_slices) - : Thunk(Kind::kInfeed, thunk_info), infeed_slices_(infeed_slices) {} + : Thunk(Kind::kInfeed, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), + infeed_slices_(infeed_slices) {} Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Infeeding to GPU: " << hlo_instruction()->ToString(); + VLOG(2) << "Infeeding to GPU: " << hlo_instruction_->ToString(); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index ec33235c466..ab410661ba1 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,6 +43,7 @@ class InfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const ShapeTree infeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index a0580e2ab04..b994ead17ca 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -29,12 +29,27 @@ limitations under the License. namespace xla { namespace gpu { +namespace { +bool ElementIsF32OrF16(const Shape& shape) { + PrimitiveType type = shape.element_type(); + return type == F32 || type == F16; +} +} // namespace + /*static*/ bool GpuInstructionFusion::IsExpensive( const HloInstruction& instruction) { - // We say that floating-point division is cheap on the GPU. - if (instruction.opcode() == HloOpcode::kDivide && - ShapeUtil::ElementIsFloating(instruction.shape())) { - return false; + // We say that some floating-point math ops are cheap on the GPU. Unlike other + // intrinsics that can be expanded into many instructions, Div and Rsqrt are + // lowered into single hardware instructions. + switch (instruction.opcode()) { + case HloOpcode::kDivide: + case HloOpcode::kRsqrt: + if (ElementIsF32OrF16(instruction.shape())) { + return false; + } + break; + default: + break; } return InstructionFusion::IsExpensive(instruction); } diff --git a/tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc b/tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc new file mode 100644 index 00000000000..2e3461951d8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir/dialect_registration.cc @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h" + +// Static initialization for GPU thunks op registration. +static mlir::DialectRegistration + xla_thunks_ops; diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc new file mode 100644 index 00000000000..154612824ef --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the Thunk dialect. + +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +namespace mlir { +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_structs.cc.inc" +namespace xla_thunks { + +XLAThunksDialect::XLAThunksDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc" + >(); +} + +#define GET_OP_CLASSES +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.cc.inc" + +} // namespace xla_thunks +} // namespace mlir diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h new file mode 100644 index 00000000000..ede9adb9ab1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_ + +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project + +namespace mlir { +class OpBuilder; + +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_structs.h.inc" + +namespace xla_thunks { + +class XLAThunksDialect : public Dialect { + public: + explicit XLAThunksDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_thunks"; } +}; + +#define GET_OP_CLASSES +#include "tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.h.inc" + +} // namespace xla_thunks +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_XLA_THUNKS_OPS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td new file mode 100644 index 00000000000..38602550864 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir/xla_thunks_ops.td @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Operation definition file for GPU thunks. + +#ifndef XLA_THUNKS_OPS +#define XLA_THUNKS_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/IR/OpBase.td" + +class LLVMPointerTo + : ContainerType().isPointerTy()">, + "$_self.cast<::mlir::LLVM::LLVMType>().getPointerElementTy()", + "LLVM pointer">; + +def XLAThunks_Dialect : Dialect { + let name = "xla_thunks"; + let cppNamespace = "xla_thunks"; +} + +class ThunkOp traits = []> : + Op; + +def AllocationSlice : StructAttr<"AllocationSlice", XLAThunks_Dialect, [ + StructFieldAttr<"allocation_index", I64Attr>, + StructFieldAttr<"offset", I64Attr>, + StructFieldAttr<"size", I64Attr>, + ]> { + let description = "Defines a slice of an allocation for XLA thunk ops"; +} + +def MemzeroThunkOp : ThunkOp<"execute_memzero_thunk"> { + let arguments = (ins + LLVMPointerTo>:$execute_params, + AllocationSlice:$allocation_slice + ); + let results = (outs + I<1>:$ok, + LLVMPointerTo>:$error_message + ); +} + +#endif // XLA_THUNKS_OPS diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 04e24733971..31203b9c5f0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -192,9 +192,6 @@ Status IrEmitter::EmitCallToNestedComputation( llvm::Value* casted_output = AddrCastToDefault(output, b_); arguments.push_back(casted_output); - // It is not required to do address space cast because TempBufferBase - // is always in addrspace 0. 
- arguments.push_back(bindings_.GetTempBufferBase()); Call(emitted_function, arguments); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 72f48c49096..e96c5f05e60 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -67,8 +67,6 @@ Status IrEmitterNested::CodegenNestedComputation() { root_shape, ir_emitter_context_->llvm_module()->getDataLayout()); argument_dereferenceable_bytes.push_back(root_size); } - // The base pointer of the memory block for all pre-allocated temp buffers. - argument_types.push_back(b_.getInt8PtrTy()); llvm::FunctionType* function_type = llvm::FunctionType::get(b_.getVoidTy(), argument_types, false); @@ -119,8 +117,8 @@ Status IrEmitterNested::CodegenNestedComputation() { llvm::Value* root_value = bindings_.GetBasePointer(*root_instruction); const Shape& return_shape = root_instruction->shape(); - // Second last argument is the out parameter. - llvm::Argument* out_parameter = std::prev(function->arg_end(), 2); + // Last argument is the out parameter. + llvm::Argument* out_parameter = std::prev(function->arg_end(), 1); if (ShapeUtil::IsScalar(return_shape)) { llvm::Value* ret_value = Load(root_value, "load_ret_value"); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index a232bf7fce5..61b78b6004d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" @@ -1284,6 +1285,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. + VLOG(2) << sort->name() << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, @@ -1294,6 +1296,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); + VLOG(2) << sort->name() << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1368,11 +1371,27 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { ir_emitter_context_->gpu_device_info().threads_per_block_limit || total_shared_memory_needed > ir_emitter_context_->gpu_device_info().shared_memory_per_block; + VLOG(2) << absl::StreamFormat( + "%s %s use tiling. No tiling if any of the following is true: " + "kTileSize=%d < 128, " + "kThreadsPerBlock=%d > threads_per_block_limit=%d, " + "total_shared_memory_needed=%d > shared_memory_per_block=%d", + sort->name(), (no_tiling ? 
"won't" : "will"), kTileSize, kThreadsPerBlock, + ir_emitter_context_->gpu_device_info().threads_per_block_limit, + total_shared_memory_needed, + ir_emitter_context_->gpu_device_info().shared_memory_per_block); uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); + VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", + sort->name(), num_blocks, kThreadsPerBlock); auto emit_kernel = [&](absl::Span xor_masks) { + VLOG(2) << absl::StreamFormat( + "%s uses kernel for xor masks [%s]", sort->name(), + absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { + absl::StrAppendFormat(out, "0x%x", xor_mask); + })); thunks.push_back( BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 @@ -1421,6 +1440,9 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (!xor_masks.empty()) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } + VLOG(2) << absl::StreamFormat( + "%s requires %d thunks (including any D2D copies)", sort->name(), + thunks.size()); AddThunkToThunkSequence(absl::make_unique( GetThunkInfo(sort), std::move(thunks))); @@ -1747,6 +1769,25 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( auto buffers_it = non_constant_buffers.begin(); for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { kernel_args[*buffers_it] = arg_it; + + // Annotate all allocations with LLVM's `noalias`. + // There are three kinds of allocations: + // * Read-only allocations, aka input parameters that are not aliased with + // outputs. + // * Read-write allocations, including all output buffers, some of which + // may alias with input HLO parameters, but aliased HLO buffers are always + // assigned with the same allocation. + // * The temp buffer. + // + // Read-only allocations may overlap with each other, but since they are + // not mutated, they can always be annotated with `noalias` per LLVM + // semantics. + // + // Read-write allocations and the temp buffer don't overlap with any + // allocations, therefore they can also be annotated with `noalias`. + kernel->addParamAttr( + arg_it->getArgNo(), + llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias)); } } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index d2126a8d17d..1228a1b4823 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -83,10 +83,10 @@ const int kDefaultInlineThreshold = 1100; static string GetSmName(std::pair compute_capability) { int compute_capability_version = compute_capability.first * 10 + compute_capability.second; - int sm_version = 35; + int sm_version = 30; // If the current compute capability isn't known, fallback to the // most recent version before it. - for (int v : {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35}) { + for (int v : {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35, 32, 30}) { if (v <= compute_capability_version) { sm_version = v; break; @@ -630,8 +630,10 @@ StatusOr> EmitModuleToHsaco( // Locate lld. // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after // ROCm-Device-Libs PR. 
- std::string lld_path = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin"); - auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); + std::string lld_path_1 = tensorflow::io::JoinPath("/opt/rocm", "hcc/bin"); + std::string lld_path_2 = tensorflow::io::JoinPath("/opt/rocm", "llvm/bin"); + auto lld_program = + llvm::sys::findProgramByName("ld.lld", {lld_path_1, lld_path_2}); if (!lld_program) { return xla::InternalError("unable to find ld.lld in PATH: %s", lld_program.getError().message()); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 755413beeee..25ab9a7ce6e 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -544,10 +544,11 @@ NcclAllReduceThunk::NcclAllReduceThunk( ThunkInfo thunk_info, int64 replica_count, std::vector buffers) : Thunk(Thunk::kNcclAllReduce, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), replica_count_(replica_count), buffers_(std::move(buffers)), aux_data_(absl::make_unique()) { - CHECK_EQ(hlo_instruction()->operand_count(), buffers_.size()); + CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size()); } // Figures out which devices (named by their replica-ids) are participating in @@ -557,7 +558,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); - auto* instr = Cast(hlo_instruction()); + auto* instr = Cast(hlo_instruction_); int64 local_device_ordinal = params.stream->parent()->device_ordinal(); GlobalDeviceId global_device_id; if (params.gpu_global_device_ids) { @@ -606,7 +607,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { // Find or create the rendezvous for this collective operation. RendezvousKey rendezvous_key = RendezvousKey::FromInstruction( - params.run_id, global_devices, local_devices.size(), hlo_instruction()); + params.run_id, global_devices, local_devices.size(), hlo_instruction_); if (VLOG_IS_ON(2)) { std::vector local_participants; @@ -633,13 +634,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) { pbuffer.destination_data = params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer); pbuffer.primitive_type = - hlo_instruction()->operand(i)->shape().element_type(); + hlo_instruction_->operand(i)->shape().element_type(); participant.buffers.push_back(pbuffer); } participant.local_devices = std::move(local_devices); participant.nccl_unique_id_callback = params.nccl_unique_id_callback; - auto reduction_kind = - MatchReductionComputation(hlo_instruction()->to_apply()); + auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply()); CHECK(reduction_kind.has_value()); participant.reduction_kind = *reduction_kind; diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 1df4f0805a6..cbd4fd3aa51 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -73,6 +73,7 @@ class NcclAllReduceThunk : public Thunk { // build, and we don't want to expose *that* mess in the header.) 
struct AuxData; + const HloInstruction* hlo_instruction_; const int64 replica_count_; const std::vector buffers_; std::unique_ptr aux_data_; diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 104366fd78c..83066a4addf 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -26,13 +26,14 @@ namespace gpu { OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, ShapeTree outfeed_slices) : Thunk(Kind::kOutfeed, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), outfeed_slices_(std::move(outfeed_slices)) {} Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { auto& stream = *params.stream; auto& buffer_allocations = *params.buffer_allocations; - VLOG(2) << "Outfeeding from GPU: " << hlo_instruction()->ToString(); + VLOG(2) << "Outfeeding from GPU: " << hlo_instruction_->ToString(); auto op_profiler = params.profiler->MakeScopedInstructionProfiler(profile_index()); @@ -41,13 +42,13 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { outfeed_manager->BlockingGetNextDestination(); // Nothing to be done for empty tuples. - if (ShapeUtil::IsEmptyTuple(hlo_instruction()->operand(0)->shape())) { + if (ShapeUtil::IsEmptyTuple(hlo_instruction_->operand(0)->shape())) { return Status::OK(); } - CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(), + CHECK(ShapeUtil::Compatible(hlo_instruction_->operand(0)->shape(), outfeed_buffers->shape())) << "XLA program outfeed request of shape " - << hlo_instruction()->operand(0)->shape().ToString() + << hlo_instruction_->operand(0)->shape().ToString() << " did not match the runtime's outfeed buffer of shape " << outfeed_buffers->shape().ToString(); diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h index e99174e3c6c..9174e605783 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.h @@ -41,6 +41,7 @@ class OutfeedThunk : public Thunk { Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const ShapeTree outfeed_slices_; }; diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc index 15cf2493549..903acf4f57d 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc @@ -28,12 +28,6 @@ SequentialThunk::SequentialThunk(ThunkInfo thunk_info, std::vector> thunks) : Thunk(Kind::kSequential, thunk_info), thunks_(std::move(thunks)) {} -void SequentialThunk::ComputeAnnotations() { - for (const auto& thunk : thunks_) { - thunk->ComputeAnnotations(); - } -} - Status SequentialThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { for (auto& thunk : thunks_) { diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h index 127c5bcf734..455ee60fa5c 100644 --- a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h @@ -39,7 +39,6 @@ class SequentialThunk : public Thunk { const std::vector>& thunks() const { return thunks_; } - void ComputeAnnotations() override; Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const ExecuteParams& 
params) override; diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a23c14017a4..a2bddd2d0d7 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -479,7 +479,10 @@ glob_lit_tests( "no_pip", ], driver = "@llvm-project//mlir:run_lit.sh", - test_file_exts = ["hlo"], + test_file_exts = [ + "hlo", + "mlir", + ], ) # Bundle together all of the test utilities that are used by tests. @@ -487,7 +490,17 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/xla/service/gpu/tests:hlo_to_llvm_ir", + ":hlo_to_llvm_ir", + ":xla-thunks-opt", "@llvm-project//llvm:FileCheck", ], ) + +# Binary with only the thunks dialect registered, for testing purposes. +tf_cc_binary( + name = "xla-thunks-opt", + deps = [ + "//tensorflow/compiler/mlir:tf_mlir_opt_main", + "//tensorflow/compiler/xla/service/gpu:xla_thunks_dialect_registration", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/execute_memzero_thunk.mlir b/tensorflow/compiler/xla/service/gpu/tests/execute_memzero_thunk.mlir new file mode 100644 index 00000000000..82f3f06db5c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/execute_memzero_thunk.mlir @@ -0,0 +1,15 @@ +// RUN: xla-thunks-opt %s | FileCheck --color --dump-input=fail %s + +func @main( %execute_params: !llvm.ptr ) { + // CHECK: "xla_thunks.execute_memzero_thunk" + // CHECK-SAME: {allocation_index = 0 : i64, offset = 128 : i64, size = 1024 : i64} + // CHECK-SAME: (!llvm.ptr) -> (i1, !llvm.ptr) + %ok, %error_message = + "xla_thunks.execute_memzero_thunk"( %execute_params ) + { allocation_slice = { allocation_index = 0 + , offset = 128 + , size = 1024 } } + : (!llvm.ptr) -> (i1, !llvm.ptr) + return +} + diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc index 914b81c632f..3ebac925886 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_alignment_test.cc @@ -45,7 +45,7 @@ ENTRY main { )"; CompileAndVerifyIr(hlo_string, R"( -CHECK: @fusion(i8* align 64 dereferenceable(600) %alloc0, i8* align 16 dereferenceable(400) %alloc1, i8* align 64 dereferenceable(864) %temp_buf) +CHECK: @fusion(i8* noalias align 64 dereferenceable(600) %alloc0, i8* noalias align 16 dereferenceable(400) %alloc1, i8* noalias align 64 dereferenceable(864) %temp_buf) )"); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index 38ff2da7161..8ec00d73711 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -51,16 +51,9 @@ TEST_F(GpuNoAliasTest, Concat) { hlo_module->AddEntryComputation(std::move(computation)); CompileAndVerifyIr(std::move(hlo_module), - R"( -; CHECK: %[[x_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %x{{.*}}, i32 0 -; CHECK: load float, float* %[[x_gep]], {{.*}}, !noalias ![[param_noalias:.*]] -; CHECK: %[[y_gep:.*]] = getelementptr inbounds [2 x [2 x float]], [2 x [2 x float]]* %y{{.*}}, i32 0 -; CHECK: load float, float* %[[y_gep]], {{.*}}, !noalias ![[param_noalias]] -; CHECK: %[[result_ptr:.*]] = bitcast [2 x [6 x float]]* %fusion{{.*}} to float* -; CHECK: %[[result_gep:.*]] = getelementptr inbounds float, float* %[[result_ptr]] -; CHECK: store float 
{{.*}}, float* %[[result_gep]], align 4, !alias.scope ![[param_noalias]] -; CHECK: ![[param_noalias]] = !{![[retval_buffer:.*]]} - )", + R"(CHECK-LABEL: define{{.*}}void @fusion + CHECK-SAME: i8* noalias align {{[0-9]*}} dereferenceable({{[0-9]*}}) %[[OUTPUT_ALLOC:[a-z0-9]*]] + CHECK: %fusion.raw = {{.*}} %[[OUTPUT_ALLOC]])", /*match_optimized_ir=*/false); } diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index 796c0adadd2..c9e7daeb3bc 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -1,6 +1,6 @@ // RUN: hlo_to_llvm_ir %s | FileCheck %s -// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { // CHECK: entry: // CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 @@ -26,7 +26,7 @@ // CHECK: ret void // CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_24]] // CHECK: %[[VAL_25:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_8]], i32 0, i32 %[[VAL_19]] -// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_26:.*]] = load i32, i32* %[[VAL_25]], align 4, !invariant.load !4 // CHECK: %[[VAL_27:.*]] = add i32 0, %[[VAL_26]] // CHECK: %[[VAL_28:.*]] = icmp ult i32 %[[VAL_26]], 3 // CHECK: %[[VAL_29:.*]] = and i1 true, %[[VAL_28]] @@ -37,7 +37,7 @@ // CHECK: %[[VAL_31:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_18]] // CHECK: %[[VAL_33:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_11]] to i32* // CHECK: %[[VAL_34:.*]] = getelementptr inbounds i32, i32* %[[VAL_33]], i32 %[[VAL_15]] -// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_35:.*]] = load i32, i32* %[[VAL_34]], align 4, !invariant.load !4 // CHECK: store i32 %[[VAL_35]], i32* %[[VAL_32]], align 4 // CHECK: %[[VAL_36:.*]] = load i32, i32* %[[VAL_32]], align 4 // CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4 @@ -48,9 +48,6 @@ // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} -// CHECK: !5 = !{!6} -// CHECK: !6 = !{!"buffer: {index:0, offset:0, size:36}", !7} -// CHECK: !7 = !{!"XLA global AA domain"} HloModule TensorFlowScatterV1 @@ -75,7 +72,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* align 64 dereferenceable(4) %alloc0, i8* align 16 dereferenceable(4) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 %alloc3) { +// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) { // CHECK: entry: // CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 @@ -101,7 +98,7 @@ ENTRY main { // CHECK: scatter.in_bounds-after: ; preds = %[[VAL_59]], 
%[[VAL_55]] // CHECK: br label %[[VAL_56]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_55]] -// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3, !noalias !4 +// CHECK: %[[VAL_61:.*]] = load i32, i32* %[[VAL_48]], align 4, !invariant.load !3 // CHECK: store i32 %[[VAL_61]], i32* %[[VAL_60]], align 4 // CHECK: %[[VAL_62:.*]] = load i32, i32* %[[VAL_60]], align 4 // CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4 @@ -111,9 +108,6 @@ ENTRY main { // CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} -// CHECK: !4 = !{!5} -// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:4}", !6} -// CHECK: !6 = !{!"XLA global AA domain"} HloModule ScatterIntoScalar @@ -137,7 +131,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* align 64 dereferenceable(36) %alloc0, i8* align 16 dereferenceable(36) %alloc1, i8* align 16 dereferenceable(24) %alloc2, i8* align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4 @@ -164,7 +158,7 @@ ENTRY main { // CHECK: ret void // CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_89]] // CHECK: %[[VAL_90:.*]] = getelementptr inbounds [2 x i32], [2 x i32]* %[[VAL_73]], i32 0, i32 %[[VAL_84]] -// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_91:.*]] = load i32, i32* %[[VAL_90]], align 4, !invariant.load !4 // CHECK: %[[VAL_92:.*]] = add i32 0, %[[VAL_91]] // CHECK: %[[VAL_93:.*]] = icmp ult i32 %[[VAL_91]], 3 // CHECK: %[[VAL_94:.*]] = and i1 true, %[[VAL_93]] @@ -175,7 +169,7 @@ ENTRY main { // CHECK: %[[VAL_97:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_67]], i32 0, i32 %[[VAL_92]], i32 %[[VAL_83]] // CHECK: %[[VAL_99:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_76]] to i32* // CHECK: %[[VAL_100:.*]] = getelementptr inbounds i32, i32* %[[VAL_99]], i32 %[[VAL_80]] -// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4, !noalias !5 +// CHECK: %[[VAL_101:.*]] = load i32, i32* %[[VAL_100]], align 4, !invariant.load !4 // CHECK: store i32 %[[VAL_101]], i32* %[[VAL_98]], align 4 // CHECK: %[[VAL_102:.*]] = load i32, i32* %[[VAL_98]], align 4 // CHECK: %[[VAL_103:.*]] = load i32, i32* %[[VAL_97]], align 4 @@ -186,7 +180,7 @@ ENTRY main { // CHECK: atomic_op_loop_body: ; preds = %[[VAL_104]], %[[VAL_95]] // CHECK: %[[VAL_105:.*]] = load i32, i32* %[[VAL_64]], align 4 // CHECK: store i32 %[[VAL_105]], i32* %[[VAL_63]], align 4 -// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]], i8* null) +// CHECK: call void @mul_s32(i32* %[[VAL_63]], i32* %[[VAL_98]], i32* %[[VAL_63]]) // CHECK: %[[VAL_106:.*]] = load i32, i32* %[[VAL_63]], align 4 // CHECK: %[[VAL_107:.*]] = cmpxchg i32* %[[VAL_97]], i32 %[[VAL_105]], i32 %[[VAL_106]] seq_cst seq_cst // CHECK: %[[VAL_108:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 0 @@ -199,15 +193,6 @@ ENTRY main { // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} -// CHECK: !5 = !{!6} -// CHECK: !6 = 
!{!"buffer: {index:0, offset:0, size:36}", !7} -// CHECK: !7 = !{!"XLA global AA domain"} -// CHECK: !8 = !{!9} -// CHECK: !9 = !{!"buffer: {index:4, offset:0, size:4}", !7} -// CHECK: !10 = !{!11} -// CHECK: !11 = !{!"buffer: {index:6, offset:0, size:4}", !7} -// CHECK: !12 = !{!13} -// CHECK: !13 = !{!"buffer: {index:5, offset:0, size:4}", !7} HloModule TensorFlowScatter_Mul @@ -231,7 +216,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* align 64 dereferenceable(16) %alloc0, i8* align 16 dereferenceable(16) %alloc1, i8* align 16 dereferenceable(4) %alloc2, i8* align 16 dereferenceable(4) %alloc3) { +// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) { // CHECK: entry: // CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 @@ -253,7 +238,7 @@ ENTRY main { // CHECK: scatter_ScalarUpdate.in_bounds-after: ; preds = %[[VAL_138:.*]], %[[VAL_139:.*]] // CHECK: ret void // CHECK: scatter_ScalarUpdate.in_bounds-true: ; preds = %[[VAL_139]] -// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3, !noalias !4 +// CHECK: %[[VAL_140:.*]] = load i32, i32* %[[VAL_126]], align 4, !invariant.load !3 // CHECK: %[[VAL_141:.*]] = add i32 0, %[[VAL_140]] // CHECK: %[[VAL_142:.*]] = icmp ult i32 %[[VAL_140]], 4 // CHECK: %[[VAL_143:.*]] = and i1 true, %[[VAL_142]] @@ -262,7 +247,7 @@ ENTRY main { // CHECK: br label %[[VAL_137]] // CHECK: scatter.in_bounds-true: ; preds = %[[VAL_136]] // CHECK: %[[VAL_145:.*]] = getelementptr inbounds [4 x i32], [4 x i32]* %[[VAL_120]], i32 0, i32 %[[VAL_141]] -// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3, !noalias !4 +// CHECK: %[[VAL_147:.*]] = load i32, i32* %[[VAL_129]], align 4, !invariant.load !3 // CHECK: store i32 %[[VAL_147]], i32* %[[VAL_146]], align 4 // CHECK: %[[VAL_148:.*]] = load i32, i32* %[[VAL_146]], align 4 // CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4 @@ -272,9 +257,6 @@ ENTRY main { // CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} -// CHECK: !4 = !{!5} -// CHECK: !5 = !{!"buffer: {index:0, offset:0, size:16}", !6} -// CHECK: !6 = !{!"XLA global AA domain"} HloModule ScalarUpdate diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo new file mode 100644 index 00000000000..272c9a25769 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo @@ -0,0 +1,394 @@ +// RUN: hlo_to_llvm_ir %s | FileCheck %s + +HloModule TestModule + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] 
to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define internal void @compare(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_0_LHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_0_RHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] +// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 
+// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 +// CHECK-NEXT: ret void + +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) { +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 +// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3 +// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] +// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: call void @compare(float* [[TMP11]], float* [[TMP12]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP13:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP13]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP14:%.*]] = load float, float* [[TMP11]], align 4 +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store float [[TMP14]], float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* 
[[TMP17]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) { +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 +// CHECK-NEXT: br 
label [[IS_SMALLER_THAN_AFTER]] +ENTRY main { + x = f32[2, 3] parameter(0) + ROOT sort = f32[2, 3] sort(x), dimensions={1}, to_apply=compare +} + +// ----- + +HloModule TestModule + +compare { + p.0.lhs = s32[] parameter(0) + p.0.rhs = s32[] parameter(1) + p.1.lhs = f32[] parameter(2) + p.1.rhs = f32[] parameter(3) + ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT +} + +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 
[[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(i32* [[TMP12]], i32* [[TMP13]], float* [[TMP14]], float* [[TMP15]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP16:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP16]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = load i32, i32* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store i32 [[TMP18]], i32* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = load float, float* [[TMP15]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP22]], float* [[TMP24]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define internal void @compare(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_1_LHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_1_RHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] +// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 +// CHECK-NEXT: ret void + +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: 
[[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 +// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3 +// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] +// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: call void @compare(i32* [[TMP11]], i32* [[TMP12]], float* [[TMP13]], float* [[TMP14]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP15:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP15]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP16:%.*]] = load i32, i32* [[TMP11]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store i32 [[TMP16]], i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = 
load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] + +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK-NEXT: entry: +// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK: sort.in_bounds-after: +// CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x [3 x i32]]* [[SORT_TYPED2]] to i8* +// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 0 +// CHECK-NEXT: store i8* [[TMP7]], i8** [[TMP8]], align 8 +// CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x [3 x float]]* [[SORT_TYPED4]] to i8* +// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 1 +// CHECK-NEXT: store i8* [[TMP9]], i8** [[TMP10]], align 8 +// CHECK-NEXT: ret void +// CHECK: sort.in_bounds-true: +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 
+// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK: smaller_comparison_index-after: +// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] +// CHECK: smaller_comparison_index-true: +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: call void @compare(i32* [[TMP16]], i32* [[TMP17]], float* [[TMP18]], float* [[TMP19]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP20:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP20]], 0 +// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] +// CHECK: is_smaller_than-after: +// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] +// CHECK: is_smaller_than-true: +// CHECK-NEXT: [[TMP21:%.*]] = load i32, i32* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: store i32 [[TMP21]], i32* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = load float, float* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP25]], float* [[TMP27]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 +// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] +ENTRY main { + x = s32[2, 3] parameter(0) + y = f32[2, 3] parameter(1) + ROOT sort = (s32[2, 3], f32[2, 3]) sort(x, y), dimensions={1}, to_apply=compare +} diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0a5382291c9..7a9fedec629 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -69,10 +69,12 @@ class Thunk { }; struct ThunkInfo { + // Optional. It's only used by subclasses which haven't been migrated away + // from HloInstructions. Once the migration is done, Thunks should be fully + // serializable. const HloInstruction* hlo_instruction = nullptr; absl::optional profile_index; - // TODO(timshen): Remove hlo_instruction and add name(), - // profile_annotation() here. 
+ std::string profile_annotation; }; // The hlo_instruction argument is meant to be the instruction this thunk was @@ -80,9 +82,8 @@ class Thunk { // to Thunk::hlo_instruction, so it can be null. explicit Thunk(Kind kind, ThunkInfo thunk_info) : kind_(kind), - hlo_instruction_(thunk_info.hlo_instruction), - name_(hlo_instruction_ ? hlo_instruction_->name() : ""), - profile_index_(thunk_info.profile_index) {} + profile_index_(thunk_info.profile_index), + profile_annotation_(thunk_info.profile_annotation) {} virtual ~Thunk() {} Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; @@ -90,19 +91,6 @@ class Thunk { Kind kind() const { return kind_; } string profile_annotation() const { return profile_annotation_; } - absl::string_view name() const { return name_; } - - // Constructs and caches the profile annotation string for this thunk and - // any child thunks. - virtual void ComputeAnnotations() { - const HloInstruction* hlo = hlo_instruction(); - if (hlo) { - profile_annotation_ = - absl::StrFormat("Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), - hlo->GetModule()->name()); - } - } - // Prepares the thunk for execution on the given StreamExecutor. // // This may be called multiple times. Its main purpose is to give us a chance @@ -134,14 +122,8 @@ class Thunk { virtual Status ExecuteOnStream(const ExecuteParams& params) = 0; protected: - const HloInstruction* hlo_instruction() const { return hlo_instruction_; } - absl::optional profile_index() const { return profile_index_; } - const HloModuleConfig& GetModuleConfig() const { - return hlo_instruction()->GetModule()->config(); - } - // Safely copies the given buffer to the GPU, deleting it on the host only // after the copy has completed. template @@ -156,13 +138,8 @@ class Thunk { private: Kind kind_; - - // Will be removed in the future, as Thunk is migrating away from the - // monolithic HloInstruction. - const HloInstruction* hlo_instruction_; - std::string name_; absl::optional profile_index_; - string profile_annotation_; + std::string profile_annotation_; }; // A sequence of thunks. diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc index 089d70d658f..690d0c9de56 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc @@ -386,6 +386,8 @@ Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo( CHECK(hlo); Thunk::ThunkInfo info; info.hlo_instruction = hlo; + info.profile_annotation = absl::StrFormat( + "Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), hlo->GetModule()->name()); return info; } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 3801dc8aee8..ceae39583f2 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -80,8 +80,8 @@ class ThunkSchedule { // `thunk`. // // Precondition: `operand` is a non-trivial (i.e. excluding - // thunk.hlo_instruction() itself) transitive operand of - // thunk.hlo_instruction(). + // thunk.hlo_instruction_ itself) transitive operand of + // thunk.hlo_instruction_. 
void AddDependenciesOnTransitiveOperands( const Thunk& thunk, const HloInstruction& operand, const absl::flat_hash_map& hlo_to_thunk); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 47a24552b6c..792479df4ac 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -29,6 +29,7 @@ WhileThunk::WhileThunk( std::unique_ptr condition_thunk_sequence, std::unique_ptr body_thunk_sequence) : Thunk(Kind::kWhile, thunk_info), + hlo_instruction_(thunk_info.hlo_instruction), condition_result_buffer_index_(condition_result_buffer_index), // Pass nullptr as the HloInstruction* to the condition_thunk_sequence_ // and body_thunk_sequence_ constructors because these SequentialThunks @@ -39,12 +40,6 @@ WhileThunk::WhileThunk( body_thunk_sequence_(absl::make_unique( ThunkInfo(), std::move(*body_thunk_sequence))) {} -void WhileThunk::ComputeAnnotations() { - Thunk::ComputeAnnotations(); - condition_thunk_sequence_->ComputeAnnotations(); - body_thunk_sequence_->ComputeAnnotations(); -} - Status WhileThunk::Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) { TF_RETURN_IF_ERROR( @@ -67,7 +62,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { profiler.StartHloComputation(); VLOG(3) << "Executing condition computation"; TF_RETURN_IF_ERROR(condition_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction()->while_condition()); + profiler.FinishHloComputation(hlo_instruction_->while_condition()); // Copy the result of condition computation and break the loop if 'false'. bool condition_result; @@ -91,7 +86,7 @@ Status WhileThunk::ExecuteOnStream(const ExecuteParams& params) { // Invoke thunk sequence for while 'body' computation, and pass on // 'profiler' to measure the timing of the thunks in 'body_thunk_sequence_'. TF_RETURN_IF_ERROR(body_thunk_sequence_->ExecuteOnStream(params)); - profiler.FinishHloComputation(hlo_instruction()->while_body()); + profiler.FinishHloComputation(hlo_instruction_->while_body()); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 72d9415b309..707bac15bb2 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -46,12 +46,12 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - void ComputeAnnotations() override; Status Initialize(const GpuExecutable& executable, se::StreamExecutor* executor) override; Status ExecuteOnStream(const ExecuteParams& params) override; private: + const HloInstruction* hlo_instruction_; const BufferAllocation::Slice condition_result_buffer_index_; std::unique_ptr condition_thunk_sequence_; std::unique_ptr body_thunk_sequence_; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 960f60fe882..17a7b18c84b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 72 +// Next ID: 73 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -248,6 +248,9 @@ message HloInstructionProto { // RNG algorithm used by kRngBitGenerator. 
xla.RandomAlgorithm rng_algorithm = 70; + + // The comparison type used for kCompare. + string comparison_type = 72; } // Serialization of HloComputation. @@ -283,6 +286,16 @@ message HloScheduleProto { map sequences = 1; } +enum Kind { + // Define an UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. + UNDEFINED_ALIAS = 0; + // The buffers may or may not alias at runtime. + MAY_ALIAS = 1; + // The buffers must alias at runtime. + MUST_ALIAS = 2; +} + message HloInputOutputAliasProto { // The following proto describes a pair of aliased buffers: an input // (described by parameter number and a ShapeIndex of the parameter) @@ -304,8 +317,8 @@ message HloInputOutputAliasProto { int64 parameter_number = 2; // ShapeIndex of the parameter instruction. repeated int64 parameter_shape_index = 3; - reserved 4; - reserved "kind"; + // The kind of alias to be set up. + Kind kind = 4; } repeated AliasEntryProto entries = 1; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 022046209bf..d640007886c 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -509,7 +509,7 @@ class HloComputation { enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( - const HloComputation::ChannelDependencyGroup& channel_dependency_map, + const HloComputation::ChannelDependencyGroup& channel_dependency_group, std::vector* post_order, HloInstruction* root, absl::flat_hash_map* visited) const; diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 0f5267e9fbc..4ba67888409 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/comparators.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" @@ -258,6 +259,11 @@ HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo, PrimitiveType type) { CHECK_NE(hlo->shape().element_type(), type); Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + // PRED values are stored as one byte but have a BitWidth of 1; avoid this + // mismatch by using a convert instead of a bitcast convert. 
+ if (type == PRED || hlo->shape().element_type() == PRED) { + return MakeConvertToHlo(hlo, type); + } hlo = hlo->parent()->AddInstruction( HloInstruction::CreateBitcastConvert(shape, hlo)); CHECK_EQ(hlo->shape().element_type(), type); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ae8f49df4b4..acccf7aac9a 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -440,6 +440,10 @@ Status HloEvaluator::HandleSetDimensionSize( Literal result(set_dimension_size->shape()); memcpy(result.untyped_data(), operand_literal.untyped_data(), operand_literal.size_bytes()); + const Literal& size_literal = + GetEvaluatedLiteralFor(set_dimension_size->operand(1)); + result.SetDynamicSize(set_dimension_size->dimension(), + size_literal.Get({})); evaluated_[set_dimension_size] = std::move(result); return Status::OK(); } @@ -1569,9 +1573,9 @@ class OutputBatchIndexToInputIndex { int64 index_vector_dim = dim_numbers_.index_vector_dim(); for (int64 i = 0, e = index_vector_.size(); i < e; i++) { index_vector_index_[index_vector_dim] = i; - // TODO(george): OK what should happen here? - // seems OK to crash though. - index_vector_[i] = *start_indices_.GetIntegralAsS64(index_vector_index_); + auto start_index = start_indices_.GetIntegralAsS64(index_vector_index_); + TF_RET_CHECK(start_index.has_value()); + index_vector_[i] = *start_index; } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 1a154f32a6f..d5f0c62adc1 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/meta/type_traits.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" @@ -47,22 +48,26 @@ template struct is_complex_t : absl::disjunction, std::is_same> {}; +namespace detail { +template +using unsigned_promoted_type_t = + std::make_unsigned_t() + std::declval())>; +} + // ToArithmeticSafeType(T t): -// - converts `t` to the bitwise-equivalent `unsigned T` if T is a signed +// - converts `t` to an unsigned integer at least as wide as `int` if T is an // integer, and // - otherwise returns `t` unchanged. // // It's UB in C++ to under/overflow a signed integer, so we wrap all arithmetic // in this type to force 2's complement behavior. 
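Editor's note (not part of the patch): the `detail::unsigned_promoted_type_t` helper introduced just above relies on ordinary C++ integer promotion — `std::declval<T>() + std::declval<T>()` has a type at least as wide as `int`, and its unsigned counterpart gives well-defined wrap-around arithmetic. A minimal standalone sketch of the same idea, using a hypothetical `unsigned_promoted_t` alias rather than the patch's helper:

```cpp
// Standalone sketch; mirrors the idea behind detail::unsigned_promoted_type_t.
#include <cstdint>
#include <type_traits>
#include <utility>

template <typename T>
using unsigned_promoted_t =
    std::make_unsigned_t<decltype(std::declval<T>() + std::declval<T>())>;

// int8 + int8 is performed in int, so the "safe" type is unsigned int,
// not uint8_t.
static_assert(std::is_same<unsigned_promoted_t<int8_t>, unsigned int>::value,
              "narrow types are widened past their own width");
// Types already at least as wide as int simply get their unsigned counterpart.
static_assert(std::is_same<unsigned_promoted_t<int64_t>, uint64_t>::value,
              "wide types keep their width");

int main() { return 0; }
```

Presumably the point of promoting, rather than using plain `std::make_unsigned_t<T>`, is that narrow unsigned types would otherwise be promoted back to signed `int` inside the arithmetic (e.g. a 16-bit unsigned multiply is carried out in `int` and can overflow), which is exactly the undefined behavior the wrapper is meant to avoid.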
template ::value && - std::is_signed::value>::type* = nullptr> -typename std::make_unsigned::type ToArithmeticSafeType(T t) { - return static_cast::type>(t); + typename std::enable_if::value>::type* = nullptr> +detail::unsigned_promoted_type_t ToArithmeticSafeType(T t) { + return static_cast>(t); } template ::value || - !std::is_signed::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> T ToArithmeticSafeType(T t) { return std::move(t); } @@ -1076,13 +1081,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleConvolution(HloInstruction* conv) override { - auto lhs = conv->operand(0); - auto rhs = conv->operand(1); + Status HandleConvolutionWithLiterals(HloInstruction* conv, + const Literal& lhs_literal, + const Literal& rhs_literal) { const auto& window = conv->window(); const Shape& result_shape = conv->shape(); - const Shape& lhs_shape = lhs->shape(); - const Shape& rhs_shape = rhs->shape(); + const Shape& lhs_shape = lhs_literal.shape(); + const Shape& rhs_shape = rhs_literal.shape(); TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); @@ -1098,24 +1103,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); - const auto lhs_rank = lhs_shape.rank(); - const auto rhs_rank = rhs_shape.rank(); - - CHECK_EQ(num_spatial_dims + 2, lhs_rank); - CHECK_EQ(num_spatial_dims + 2, rhs_rank); - - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, conv->feature_group_count(), - conv->batch_group_count(), window, dnums)); - CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) - << "return shape set to: " << ShapeUtil::HumanString(result_shape) - << " but is inferred to be: " - << ShapeUtil::HumanString(inferred_return_shape); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - std::vector window_dimension_sizes; for (auto i : dnums.kernel_spatial_dimensions()) { window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i)); @@ -1271,9 +1258,68 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); + const Shape& result_shape = conv->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + + TF_CHECK_OK(ShapeUtil::ValidateShape(lhs_shape)); + TF_CHECK_OK(ShapeUtil::ValidateShape(rhs_shape)); + CHECK(lhs_shape.IsArray()); + CHECK(rhs_shape.IsArray()); + + const auto& dnums = conv->convolution_dimension_numbers(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); + CHECK_GE(num_spatial_dims, 0); + CHECK_EQ(window.dimensions_size(), num_spatial_dims); + + const auto lhs_rank = lhs_shape.rank(); + const auto rhs_rank = rhs_shape.rank(); + + CHECK_EQ(num_spatial_dims + 2, lhs_rank); + CHECK_EQ(num_spatial_dims + 2, rhs_rank); + + TF_ASSIGN_OR_RETURN(auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), + conv->batch_group_count(), window, dnums)); + 
CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) + << "return shape set to: " << ShapeUtil::HumanString(result_shape) + << " but is inferred to be: " + << ShapeUtil::HumanString(inferred_return_shape); + + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + const bool lhs_same = ShapeUtil::SameElementType(lhs_shape, result_shape); + const bool rhs_same = ShapeUtil::SameElementType(rhs_shape, result_shape); + if (rhs_same && lhs_same) { + return HandleConvolutionWithLiterals(conv, lhs_literal, rhs_literal); + } + if (rhs_same) { + return HandleConvolutionWithLiterals( + conv, lhs_literal.Convert(result_shape.element_type()).ValueOrDie(), + rhs_literal); + } + if (lhs_same) { + return HandleConvolutionWithLiterals( + conv, lhs_literal, + rhs_literal.Convert(result_shape.element_type()).ValueOrDie()); + } + return HandleConvolutionWithLiterals( + conv, lhs_literal.Convert(result_shape.element_type()).ValueOrDie(), + rhs_literal.Convert(result_shape.element_type()).ValueOrDie()); + } + Status HandleDot(HloInstruction* dot) override { if (dot->dot_dimension_numbers().rhs_contracting_dimensions_size() == 1 && - parent_->use_fast_path_) { + parent_->use_fast_path_ && + ShapeUtil::SameElementType(dot->operand(0)->shape(), dot->shape()) && + ShapeUtil::SameElementType(dot->operand(1)->shape(), dot->shape())) { return HandleDot(dot); } return HandleDotSlowPath(dot); @@ -1342,23 +1388,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleDotSlowPath(dot); } - Status HandleDotSlowPath(HloInstruction* dot) { - auto lhs = dot->operand(0); - auto rhs = dot->operand(1); - CHECK(dot->shape().IsArray()); - CHECK(lhs->shape().IsArray()); - CHECK(rhs->shape().IsArray()); - + Status HandleDotSlowPathWithLiterals(HloInstruction* dot, + const Literal& lhs_literal, + const Literal& rhs_literal) { const auto& dnums = dot->dot_dimension_numbers(); - const auto lhs_rank = lhs->shape().rank(); - const auto rhs_rank = rhs->shape().rank(); + const auto lhs_rank = lhs_literal.shape().rank(); + const auto rhs_rank = rhs_literal.shape().rank(); - CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); - CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - - const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); - const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), rhs_literal.shape())); + CHECK(ShapeUtil::SameElementType(lhs_literal.shape(), dot->shape())); CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); @@ -1406,7 +1445,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 rhs_dnum = dnums.rhs_contracting_dimensions(i); accumulate_index_locations.push_back( {&lhs_index[lhs_dnum], &rhs_index[rhs_dnum]}); - const int64 dim_size = lhs->shape().dimensions(lhs_dnum); + const int64 dim_size = lhs_literal.shape().dimensions(lhs_dnum); accumulate_index_sizes.push_back(dim_size); } const int64 total_contraction_size = Product(accumulate_index_sizes); @@ -1457,6 +1496,36 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleDotSlowPath(HloInstruction* dot) { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); + CHECK(dot->shape().IsArray()); + CHECK(lhs->shape().IsArray()); + CHECK(rhs->shape().IsArray()); + const bool lhs_same = + ShapeUtil::SameElementType(lhs->shape(), 
dot->shape()); + const bool rhs_same = + ShapeUtil::SameElementType(rhs->shape(), dot->shape()); + const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); + const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); + if (lhs_same && rhs_same) { + return HandleDotSlowPathWithLiterals(dot, lhs_literal, rhs_literal); + } + if (lhs_same) { + return HandleDotSlowPathWithLiterals( + dot, lhs_literal, + rhs_literal.Convert(dot->shape().element_type()).ValueOrDie()); + } + if (rhs_same) { + return HandleDotSlowPathWithLiterals( + dot, lhs_literal.Convert(dot->shape().element_type()).ValueOrDie(), + rhs_literal); + } + return HandleDotSlowPathWithLiterals( + dot, lhs_literal.Convert(dot->shape().element_type()).ValueOrDie(), + rhs_literal.Convert(dot->shape().element_type()).ValueOrDie()); + } + Status HandlePad(HloInstruction* pad) override { CHECK(pad->operand(0)->shape().IsArray()); // Padding value must be scalar. @@ -2344,39 +2413,23 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } // Enable CLZ only for int32, uint32, int64 and uint64. - template < - typename NativeT, - typename std::enable_if< - (std::is_floating_point::value || - std::is_integral::value || is_complex_t::value) && - !(std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + template ::value || + std::is_same::value)>::type* = nullptr> Status HandleClz(HloInstruction* clz) { return UnsupportedTypeError(clz); } template ::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value && + !std::is_same::value>::type* = nullptr> Status HandleClz(HloInstruction* clz) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - return 31 - tensorflow::Log2Floor(elem_operand); - })); - return Status::OK(); - } - - template ::value || - std::is_same::value>::type* = nullptr> - Status HandleClz(HloInstruction* clz) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clz], - ElementWiseUnaryOp(clz, [](ElementwiseT elem_operand) { - return 63 - tensorflow::Log2Floor64(elem_operand); + return (sizeof(elem_operand) * CHAR_BIT - 1) - + tensorflow::Log2Floor64(elem_operand); })); return Status::OK(); } @@ -2385,23 +2438,18 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleClz(clz); } - // Enable Popcnt only for int32, uint32, int64 and uint64. 
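Editor's note (not part of the patch): the HandleClz hunk above collapses the separate 32-bit and 64-bit lambdas into one that derives the width from the element type, i.e. clz(x) = (bit_width - 1) - floor(log2(x)) for x > 0, and the HandlePopulationCount hunk that follows only widens its SFINAE guard to all non-bool integral types in the same spirit. A self-contained sketch of the identity (the shift loop stands in for `tensorflow::Log2Floor64`, which is not assumed here):

```cpp
#include <climits>
#include <cstdint>
#include <iostream>

// clz(x) == (bit_width - 1) - floor(log2(x)) for any x > 0, which is the form
// the generalized evaluator lambda uses.
template <typename T>
int ClzViaLog2Floor(T x) {
  int log2_floor = -1;  // floor(log2(x)) computed by shifting; requires x > 0.
  for (T v = x; v != 0; v >>= 1) ++log2_floor;
  return static_cast<int>(sizeof(T) * CHAR_BIT) - 1 - log2_floor;
}

int main() {
  std::cout << ClzViaLog2Floor<uint8_t>(1) << "\n";          // 7
  std::cout << ClzViaLog2Floor<uint32_t>(1u << 20) << "\n";  // 11
  std::cout << ClzViaLog2Floor<uint64_t>(1ull) << "\n";      // 63
  return 0;
}
```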
template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value)>::type* = nullptr> + (!std::is_integral::value || + std::is_same::value)>::type* = nullptr> Status HandlePopulationCount(HloInstruction* popcnt) { return UnsupportedTypeError(popcnt); } template ::value || - std::is_same::value || - std::is_same::value || - std::is_same::value>::type* = nullptr> + std::is_integral::value && + !std::is_same::value>::type* = nullptr> Status HandlePopulationCount(HloInstruction* popcnt) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[popcnt], diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index a50af6bf1b9..d7e8984dee8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1181,7 +1181,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { instr_shape = StrCat( absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } - lines.push_back(instr_shape); + lines.push_back(HtmlLikeStringSanitize(instr_shape)); } if (debug_options_.xla_hlo_graph_addresses()) { lines.push_back(StrFormat("[%p]", instr)); diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc index e123161720b..34bc30d641f 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_module.h" namespace xla { @@ -24,9 +25,10 @@ bool HloInputOutputAliasConfig::OutputHasAlias( return alias_.element(output_index).has_value(); } -Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, - int64 param_number, - const ShapeIndex& param_index) { +Status HloInputOutputAliasConfig::SetUpAlias( + const ShapeIndex& output_index, int64 param_number, + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind must_alias) { TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index)) << "Trying to set up alias at " << output_index.ToString() << " which is an invalid index for shape " @@ -41,7 +43,8 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index, param_number, param_index.ToString(), output_index.ToString(), alias_.element(output_index)->parameter_number, alias_.element(output_index)->parameter_index.ToString()); - (*alias_.mutable_element(output_index)) = Alias(param_number, param_index); + (*alias_.mutable_element(output_index)) = + Alias(param_number, param_index, must_alias); VLOG(4) << "Set up alias between output index " << output_index.ToString() << " and parameter " << param_index << " at index " << param_index.ToString(); @@ -61,6 +64,11 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const { for (int64 i : data->parameter_index) { entry.add_parameter_shape_index(i); } + if (data->must_alias()) { + entry.set_kind(Kind::MUST_ALIAS); + } else { + entry.set_kind(Kind::MAY_ALIAS); + } result.add_entries()->Swap(&entry); } }); @@ -77,8 +85,9 @@ StatusOr HloInputOutputAliasConfig::CreateFromProto( int64 param_number = entry.parameter_number(); ShapeIndex param_index(entry.parameter_shape_index().begin(), entry.parameter_shape_index().end()); + AliasKind kind = 
entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias; TF_RETURN_IF_ERROR( - result.SetUpAlias(output_index, param_number, param_index)); + result.SetUpAlias(output_index, param_number, param_index, kind)); } return result; } @@ -93,9 +102,9 @@ string HloInputOutputAliasConfig::ToString() const { ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) { pieces.push_back(absl::StrFormat( - " OutputIndex %s is aliased with parameter %lld at %s:", - output_index.ToString(), alias.parameter_number, - alias.parameter_index.ToString())); + " OutputIndex %s is %saliased with parameter %lld at %s:", + output_index.ToString(), alias.kind == kMustAlias ? "must-" : "may-", + alias.parameter_number, alias.parameter_index.ToString())); }); return absl::StrJoin(pieces, "\n"); } @@ -112,6 +121,19 @@ string HloInputOutputAliasConfig::ToShortString() const { return absl::StrJoin(pieces, ", "); } +bool HloInputOutputAliasConfig::ParameterMustAlias( + int64 param_number, const ShapeIndex& param_index) const { + bool result = false; + alias_.ForEachElement( + [&](const xla::ShapeIndex&, absl::optional alias) { + if (alias && alias->parameter_number == param_number && + alias->parameter_index == param_index && alias->must_alias()) { + result = true; + } + }); + return result; +} + absl::optional HloInputOutputAliasConfig::GetAliasedOutput( int64 param_number, const ShapeIndex& param_index) const { absl::optional output; diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h index d5ca28e9387..6b84bdb6a68 100644 --- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h +++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config.h @@ -32,22 +32,32 @@ class HloModule; // parameter index in the entry computation. class HloInputOutputAliasConfig { public: + // The kind of aliases which can be set. A kMayAlias is one setup at + // compilation time by the user, and has to be respected. A kMustAlias one + // might be setup by the compiler, if it decides it is convenient to do so. + enum AliasKind { + kMayAlias, + kMustAlias, + }; // Defines the alias information for a given output buffer. A given output // buffer shape index can refer only to one parameter+index. struct Alias { - Alias(int64 parameter_number, ShapeIndex parameter_index) + Alias(int64 parameter_number, ShapeIndex parameter_index, + AliasKind kind = kMayAlias) : parameter_number(parameter_number), - parameter_index(std::move(parameter_index)) {} + parameter_index(std::move(parameter_index)), + kind(kind) {} int64 parameter_number; ShapeIndex parameter_index; + AliasKind kind; + + bool must_alias() const { return kind == kMustAlias; } std::string ToString() { - if (parameter_index.empty()) { - return absl::StrCat(parameter_number); - } - return absl::StrFormat("(%lld, %s)", parameter_number, - parameter_index.ToString()); + return absl::StrFormat("(%lld, %s, %s)", parameter_number, + parameter_index.ToString(), + kind == kMustAlias ? "must_alias" : "may_alias"); } }; @@ -61,7 +71,8 @@ class HloInputOutputAliasConfig { // Sets up alias config from `output_index` to `param_index` at // `param_number`. Status SetUpAlias(const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index); + const ShapeIndex& param_index, + AliasKind must_alias = kMayAlias); // Returns true if the given parameter is aliased with one of the output // buffers. 
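Editor's note (not part of the patch): a sketch of how the new `AliasKind` plumbing might be exercised end to end. It assumes an `xla::HloModule*` whose entry computation returns a tuple and the existing `HloModule::input_output_alias_config()` accessor; the function name `SetUpMustAlias` is hypothetical.

```cpp
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Requests that tuple output element {0} share its buffer with parameter 0
// as a hard constraint, then reads the decision back through the new
// ParameterMustAlias() query.
Status SetUpMustAlias(HloModule* module) {
  HloInputOutputAliasConfig& config = module->input_output_alias_config();
  Status status = config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{},
      HloInputOutputAliasConfig::kMustAlias);
  if (!status.ok()) {
    return status;
  }
  if (!config.ParameterMustAlias(/*param_number=*/0, /*param_index=*/{})) {
    return InvalidArgument("parameter 0 was expected to be must-aliased");
  }
  return Status::OK();
}

}  // namespace xla
```

With `kMayAlias` (the default fourth argument) the same `SetUpAlias` call would succeed, but `ParameterMustAlias` would return false, matching the distinction the `ToProto`/`CreateFromProto` round trip above preserves.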
@@ -92,6 +103,11 @@ class HloInputOutputAliasConfig { absl::optional GetAliasedParameter( const ShapeIndex& output_index) const; + // Returns if the parameter at the given parameter number and parameter + // index must-alias with an output. + bool ParameterMustAlias(int64 param_number, + const ShapeIndex& param_index) const; + using AliasFn = std::function; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9957df41f1a..2ce3c12b4e9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -174,8 +174,19 @@ StatusOr> HloInstruction::CreateFromProto( comparison_direction, StringToComparisonDirection(proto.comparison_direction())); } - instruction = - CreateCompare(shape, operands(0), operands(1), *comparison_direction); + auto comparison_type_str = proto.comparison_type(); + if (!comparison_type_str.empty()) { + // If a comparison type is specified, it *must* be valid. + TF_ASSIGN_OR_RETURN(auto comparison_type, + StringToComparisonType(comparison_type_str)); + instruction = CreateCompare(shape, operands(0), operands(1), + *comparison_direction, comparison_type); + } else { + // Allow the comparison type to be optional. + // The comparison type will be determined by the types of the operands. + instruction = CreateCompare(shape, operands(0), operands(1), + *comparison_direction); + } break; } case HloOpcode::kTriangularSolve: { @@ -926,8 +937,9 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, /* static */ std::unique_ptr HloInstruction::CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - ComparisonDirection direction) { - return absl::make_unique(shape, lhs, rhs, direction); + ComparisonDirection direction, absl::optional type) { + return absl::make_unique(shape, lhs, rhs, direction, + type); } /* static */ std::unique_ptr @@ -1750,10 +1762,10 @@ void HloInstruction::DetachFromOperandsAndUsers() { } } -std::unique_ptr HloInstruction::Clone( - const string& suffix, HloCloneContext* context) const { +std::unique_ptr HloInstruction::CloneWithNewShape( + const Shape& shape, const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, context); + CloneWithNewOperands(shape, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1790,6 +1802,13 @@ std::unique_ptr HloInstruction::Clone( return clone; } +std::unique_ptr HloInstruction::Clone( + const string& suffix, HloCloneContext* context) const { + std::unique_ptr clone = + CloneWithNewShape(shape_, suffix, context); + return clone; +} + std::pair HloInstruction::LatestNonGteAncestorAndIndex() const { const HloInstruction* hlo = this; @@ -2189,6 +2208,27 @@ Status HloInstruction::ReplaceOperandWithDifferentShape( return Status::OK(); } +Status HloInstruction::ReplaceUsesWith(absl::Span users, HloInstruction* new_producer) { + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) + << shape() << " is not compatible with " << new_producer->shape(); + return ReplaceAllUsesWithDifferentShape(users, new_producer); +} + +Status HloInstruction::ReplaceAllUsesWithDifferentShape( + absl::Span users, HloInstruction* new_producer) { + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer)); + } + + if (parent_ && parent_->root_instruction() == this) {
parent_->set_root_instruction(new_producer, + /*accept_different_shape=*/true); + } + return Status::OK(); +} + Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { TF_RET_CHECK( ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8c50a9bb8fc..bdd64c908f0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -595,7 +595,8 @@ class HloInstruction { // Creates a compare op, performing the comparison specified in direction. static std::unique_ptr CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - Comparison::Direction direction); + Comparison::Direction direction, + absl::optional type = absl::nullopt); static std::unique_ptr CreateTriangularSolve( const Shape& shape, HloInstruction* a, HloInstruction* b, @@ -1201,6 +1202,12 @@ class HloInstruction { // Same as ReplaceAllUsesWith, but new_producer can have a different shape. Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer); + // Same as ReplaceAllUsesWith, but only replace given set of users. + Status ReplaceUsesWith(absl::Span users, + HloInstruction* new_producer); + Status ReplaceAllUsesWithDifferentShape( + absl::Span users, HloInstruction* new_producer); + // Performs a postorder DFS visit using this node as the root. If // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when // complete. If ignore_control_predecessors is true, instructions only @@ -1413,6 +1420,11 @@ class HloInstruction { std::unique_ptr Clone( const string& suffix = "clone", HloCloneContext* context = nullptr) const; + // Clones the HLO instruction as above but with new shape. + std::unique_ptr CloneWithNewShape( + const Shape& shape, const string& suffix = "clone", + HloCloneContext* context = nullptr) const; + // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( const Shape& shape, absl::Span new_operands, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 3d34fa03a80..dbc1d85d1bb 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -204,12 +204,13 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } -HloCompareInstruction::HloCompareInstruction(const Shape& shape, - HloInstruction* lhs, - HloInstruction* rhs, - ComparisonDirection direction) +HloCompareInstruction::HloCompareInstruction( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + ComparisonDirection direction, absl::optional type) : HloInstruction(HloOpcode::kCompare, shape), - compare_(direction, lhs->shape().element_type()) { + compare_(direction, type ? 
(*type) + : Comparison::DefaultComparisonType( + lhs->shape().element_type())) { AppendOperand(lhs); AppendOperand(rhs); } @@ -218,12 +219,21 @@ HloInstructionProto HloCompareInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_comparison_direction( ComparisonDirectionToString(compare_.GetDirection())); + proto.set_comparison_type(ComparisonTypeToString(compare_.GetType())); return proto; } std::vector HloCompareInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("direction=", ComparisonDirectionToString(direction()))}; + std::vector result; + result.push_back( + StrCat("direction=", ComparisonDirectionToString(direction()))); + if (compare_.GetType() != + Comparison::DefaultComparisonType(operand(0)->shape().element_type())) { + result.push_back( + StrCat("type=", ComparisonTypeToString(compare_.GetType()))); + } + return result; } bool HloCompareInstruction::IdenticalSlowPath( @@ -238,8 +248,8 @@ std::unique_ptr HloCompareInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return absl::make_unique(shape, new_operands[0], - new_operands[1], direction()); + return absl::make_unique( + shape, new_operands[0], new_operands[1], direction(), type()); } namespace { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 51317b32bd0..3f92bb92f02 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -136,8 +136,10 @@ class HloCompareInstruction : public HloInstruction { public: explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - ComparisonDirection direction); + ComparisonDirection direction, + absl::optional type); ComparisonDirection direction() const { return compare_.GetDirection(); } + Comparison::Type type() const { return compare_.GetType(); } HloInstructionProto ToProto() const override; private: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 5502665e886..749193a83ef 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -281,6 +281,7 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(ROOT); KEYWORD(maximal); KEYWORD(replicated); + KEYWORD(last_tile_dim_replicate); #undef KEYWORD @@ -495,6 +496,8 @@ string TokKindToString(TokKind kind) { return "kw_maximal"; case TokKind::kw_replicated: return "kw_replicated"; + case TokKind::kw_last_tile_dim_replicate: + return "kw_last_tile_dim_replicate"; case TokKind::kw_nan: return "kw_nan"; case TokKind::kw_inf: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 6a59f180ad8..b8c7debaab4 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -61,6 +61,7 @@ enum class TokKind { kw_false, kw_maximal, kw_replicated, + kw_last_tile_dim_replicate, kw_nan, kw_inf, diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c715d016c4f..4a67c1d2146 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" +#include #include #include #include @@ -442,6 +443,7 @@ StatusOr HloModule::CreateModuleConfigFromShape( } module_config.set_use_spmd_partitioning( execution_options->use_spmd_partitioning()); + module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo()); if (execution_options->has_device_assignment()) { TF_ASSIGN_OR_RETURN(std::unique_ptr device_assignment, DeviceAssignment::Deserialize( @@ -650,30 +652,28 @@ bool CompareComputationsByContent(HloComputation* a, HloComputation* b) { } // anonymous namespace std::vector HloModule::MakeComputationSorted() const { - std::vector result; - result.reserve(computations_.size()); - for (const auto& computation : computations_) { - result.push_back(computation.get()); + std::vector result = MakeComputationPostOrder(); + if (config().content_aware_computation_sorting()) { + absl::c_sort(result, CompareComputationsByContent); } - std::sort(result.begin(), result.end(), CompareComputationsByContent); return result; } std::vector HloModule::MakeNonfusionComputations() const { - std::vector result; - for (auto* c : computations()) { - if (c->IsFusionComputation()) { - continue; - } - result.push_back(c); - } + std::vector result = MakeComputationPostOrder(); + result.erase(std::remove_if( + result.begin(), result.end(), + [](HloComputation* c) { return c->IsFusionComputation(); }), + result.end()); return result; } std::vector HloModule::MakeNonfusionComputationsSorted() const { auto result = MakeNonfusionComputations(); - std::sort(result.begin(), result.end(), CompareComputationsByContent); + if (config().content_aware_computation_sorting()) { + absl::c_sort(result, CompareComputationsByContent); + } return result; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 0abf3a496f7..ae0a8aae838 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -138,6 +138,13 @@ class HloModuleConfig { } bool use_spmd_partitioning() const { return use_spmd_partitioning_; } + // If enabled, deduplicate equivalent hlos into function calls to reduce code + // size. + void set_deduplicate_hlo(bool deduplicate_hlo) { + deduplicate_hlo_ = deduplicate_hlo; + } + bool deduplicate_hlo() const { return deduplicate_hlo_; } + // Return a string which unambiguously represents all the fields of this data // structure. Used for generating a cache key for storing the compiled // executable. @@ -188,6 +195,14 @@ class HloModuleConfig { alias_passthrough_params_ = alias_passthrough_params; } + bool content_aware_computation_sorting() const { + return content_aware_computation_sorting_; + } + void set_content_aware_computation_sorting( + bool content_aware_computation_sorting) { + content_aware_computation_sorting_ = content_aware_computation_sorting; + } + FusionConfigCollection fusion_config_collection() const { return fusion_config_collection_; } @@ -238,6 +253,10 @@ class HloModuleConfig { // needs to partition the module. bool use_spmd_partitioning_ = false; + // If enabled, deduplicate equivalent hlos into function calls to reduce code + // size. + bool deduplicate_hlo_ = false; + // The target maximum parallelism at which to partition HLOs for parallel // execution on the CPU backend. 
int64 intra_op_parallelism_threads_ = -1; @@ -251,6 +270,8 @@ bool alias_passthrough_params_ = false; + bool content_aware_computation_sorting_ = false; + FusionConfigCollection fusion_config_collection_ = FusionConfigCollection::kOff; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ec0540b8607..2afa06a5df4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -194,6 +194,7 @@ class HloParserImpl : public HloParser { kBracedHloComputationList, kFftType, kComparisonDirection, + kComparisonType, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -327,6 +328,7 @@ class HloParserImpl : public HloParser { bool ParseOpcode(HloOpcode* result); bool ParseFftType(FftType* result); bool ParseComparisonDirection(ComparisonDirection* result); + bool ParseComparisonType(Comparison::Type* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); bool ParseRandomAlgorithm(RandomAlgorithm* result); @@ -552,33 +554,39 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) { return false; } - if (lexer_.GetKind() != TokKind::kLparen) { - // Short form: "{0}: 0", output index "{}" is assumed. - int64 param_num; - ParseInt64(&param_num); - data->emplace(std::piecewise_construct, std::forward_as_tuple(out), - std::forward_as_tuple(param_num, ShapeIndex{})); - } else { - // Long form: "{0}: (0, {0})", output index is explicitly specified. - if (!ParseToken(TokKind::kLparen, errmsg)) { - return false; - } - int64 param_num; - ParseInt64(&param_num); - if (!ParseToken(TokKind::kComma, errmsg)) { - return false; - } - ShapeIndex param_idx; - if (!ParseShapeIndex(&param_idx)) { - return false; - } - data->emplace(std::piecewise_construct, std::forward_as_tuple(out), - std::forward_as_tuple(param_num, param_idx)); - if (!ParseToken(TokKind::kRparen, errmsg)) { - return false + if (!ParseToken(TokKind::kLparen, errmsg)) { + return false; + } + int64 param_num; + ParseInt64(&param_num); + if (!ParseToken(TokKind::kComma, errmsg)) { + return false; + } + ShapeIndex param_idx; + if (!ParseShapeIndex(&param_idx)) { + return false; + } + + HloInputOutputAliasConfig::AliasKind alias_kind = + HloInputOutputAliasConfig::kMayAlias; + if (EatIfPresent(TokKind::kComma)) { + std::string type; + ParseName(&type); + if (type == "must-alias") { + alias_kind = HloInputOutputAliasConfig::kMustAlias; + } else if (type == "may-alias") { + alias_kind = HloInputOutputAliasConfig::kMayAlias; + } else { + return TokenError("Unexpected aliasing kind; expected must-alias or may-alias"); } } + data->emplace(std::piecewise_construct, std::forward_as_tuple(out), + std::forward_as_tuple(param_num, param_idx, alias_kind)); + if (!ParseToken(TokKind::kRparen, errmsg)) { + return false; + } + if (!EatIfPresent(TokKind::kComma)) { break; } @@ -624,8 +632,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module) { if (aliasing_data) { HloInputOutputAliasConfig alias_config(module->result_shape()); for (auto& p : *aliasing_data) { - Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number, - p.second.parameter_index); + Status st = + alias_config.SetUpAlias(p.first, p.second.parameter_number, + p.second.parameter_index, p.second.kind); if (!st.ok()) { return TokenError(st.error_message()); } @@ -1355,14 +1364,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, } case HloOpcode::kCompare: { optional direction; + optional
type; attrs["direction"] = {/*required=*/true, AttrTy::kComparisonDirection, &direction}; + attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateCompare( - shape, operands[0], operands[1], *direction)); + shape, operands[0], operands[1], *direction, type)); break; } case HloOpcode::kCholesky: { @@ -2129,6 +2140,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; + bool last_tile_dim_replicate = false; std::vector devices; std::vector tile_assignment_dimensions; while (lexer_.GetKind() != TokKind::kRbrace) { @@ -2180,6 +2192,10 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, } break; } + case TokKind::kw_last_tile_dim_replicate: + last_tile_dim_replicate = true; + lexer_.Lex(); + break; case TokKind::kRbrace: break; default: @@ -2218,6 +2234,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, for (int64 device : devices) { sharding->add_tile_assignment_devices(device); } + sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate); } lexer_.Lex(); @@ -2674,7 +2691,9 @@ struct MinMaxFiniteValue { template <> struct MinMaxFiniteValue { - static double max() { return static_cast(bfloat16::highest()); } + static double max() { + return static_cast(Eigen::NumTraits::highest()); + } static double min() { return -max(); } }; @@ -3003,6 +3022,14 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kComparisonType: { + Comparison::Type result; + if (!ParseComparisonType(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kEnum: { if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects an enumeration value"); @@ -3597,7 +3624,7 @@ bool HloParserImpl::ParseHloComputationList( if (!ParseHloComputation(&computation)) { return false; } - LOG(INFO) << "parsed computation " << computation->name(); + VLOG(3) << "parsed computation " << computation->name(); result->push_back(computation); return true; }; @@ -4115,7 +4142,7 @@ bool HloParserImpl::ParseFftType(FftType* result) { } bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) { - VLOG(1) << "ParseComparisonDirection"; + VLOG(3) << "ParseComparisonDirection"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects comparison direction"); } @@ -4130,6 +4157,21 @@ bool HloParserImpl::ParseComparisonDirection(ComparisonDirection* result) { return true; } +bool HloParserImpl::ParseComparisonType(Comparison::Type* result) { + VLOG(1) << "ParseComparisonType"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects comparison type"); + } + std::string val = lexer_.GetStrVal(); + auto status_or_result = StringToComparisonType(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects comparison type but sees: %s", val)); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParserImpl::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(3) << "ParseFusionKind"; if (lexer_.GetKind() != TokKind::kIdent) { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 484578e5e0e..aba6aeff999 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -230,7 +230,7 @@ R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} - %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated} + %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, type=TOTALORDER, sharding={replicated} ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } @@ -512,7 +512,7 @@ R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) %rhs = f32[] parameter(1) - ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE + ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE, type=TOTALORDER } %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { @@ -2399,7 +2399,7 @@ ENTRY c2 { TEST_F(HloParserTest, SimpleAliasing) { const string original = R"( -HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) } +HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) } ENTRY entry { %p = (f32[], f32[]) parameter(0) @@ -2413,42 +2413,13 @@ ENTRY entry { std::unique_ptr parsed_module = module.ConsumeValueOrDie(); EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}), ShapeIndex{0}); + + EXPECT_TRUE( + parsed_module->input_output_alias_config().ParameterMustAlias(0, {0})); EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}), ShapeIndex{1}); -} - -TEST_F(HloParserTest, SimpleAliasingShortForm) { - const string original = R"( -HloModule Module, input_output_alias={ {0}: 0, {1}: 1 } - -ENTRY entry { - %p0 = f32[] parameter(0) - %p1 = f32[] parameter(1) - ROOT %out = (f32[], f32[]) tuple(%p0, %p1) -} - )"; - auto module = ParseAndReturnVerifiedModule(original); - TF_ASSERT_OK(module.status()); - std::unique_ptr parsed_module = module.ConsumeValueOrDie(); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {}), - ShapeIndex{0}); - EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(1, {}), - ShapeIndex{1}); -} - -TEST_F(HloParserTest, SimpleAliasingShortFormError) { - const string original = R"( -HloModule Module, input_output_alias={ {0}: A, {1}: 1 } - -ENTRY entry { - %p0 = f32[] parameter(0) - %p1 = f32[] parameter(1) - ROOT %out = (f32[], f32[]) tuple(%p0, %p1) -} - )"; - ExpectHasSubstr( - ParseAndReturnUnverifiedModule(original).status().error_message(), - "expects integer"); + EXPECT_FALSE( + parsed_module->input_output_alias_config().ParameterMustAlias(0, {1})); } TEST_F(HloParserTest, NestedAliasing) { @@ -2626,6 +2597,21 @@ TEST_F(HloParserTest, ParseSharding) { EXPECT_EQ(sharding.ToString(), original); } +TEST_F(HloParserTest, ParseShardingPartialReplication) { + const string original = "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); + Array group_tiling({2}); + group_tiling(0) = 0; + group_tiling(1) = 1; + std::vector group0_members({0, 1}); + std::vector group1_members({2, 3}); + EXPECT_EQ( + HloSharding::PartialTile(group_tiling, {group0_members, group1_members}) + .ToString(), + original); +} + TEST_F(HloParserTest, ParseFrontendAttributes) { 
const string original = R"({attr_a="test_a",attr_b="b",attr_c="s64",attr_d="a/b"})"; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 2166ecdd890..7f974a618a8 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -121,9 +121,9 @@ struct Item { bool placed = false; // To avoid an infinite loop rematerializing the same set of - // instructions ad infinitum, keep a blacklist of instructions + // instructions ad infinitum, keep a denylist of instructions // which should not be rematerialized. - bool blacklisted = false; + bool denylisted = false; // The buffers defined by this instruction. BufferIdList buffers_defined; @@ -292,8 +292,8 @@ class InstructionList { InsertBeforeInstructions(to_insert, {max_position_item->next}); } - void Blacklist(const HloInstruction* inst) { - GetItem(inst)->blacklisted = true; + void Denylist(const HloInstruction* inst) { + GetItem(inst)->denylisted = true; } private: @@ -745,7 +745,7 @@ Status MemoryUsageTracker::EndInstruction() { for (BufferId buffer_id : in_progress_item_->buffers_used) { Buffer& buffer = buffers_.at(buffer_id); buffer.unfinished_user_count--; - CHECK_GE(buffer.unfinished_user_count, 0) + TF_RET_CHECK(buffer.unfinished_user_count >= 0) << buffer.ToString() << " has negative unfinished user count."; if (buffer.unfinished_user_count == 0) { // Buffer is now dead. @@ -1158,13 +1158,13 @@ std::vector GetInitialBlock(const InstructionList& instruction_list, return item_block; } -// Returns whether any instruction in 'block' is blacklisted or +// Returns whether any instruction in 'block' is denylisted or // non-rematerializable. -bool AnyBlacklistedOrNonRematerializable( +bool AnyDenylistedOrNonRematerializable( const std::vector& block, absl::flat_hash_map* rematerializable_map) { for (auto* item : block) { - if (item->blacklisted) { + if (item->denylisted) { return true; } if (!CanBeRematerialized(item->instruction, rematerializable_map)) { @@ -1195,10 +1195,10 @@ MemoryUsageTracker::PickRematerializationCandidates( // instructions. break; } - // If any item in the starting block are blacklisted or non-rematable, then + // If any item in the starting block are denylisted or non-rematable, then // break and move on to next start_item (we can actually move to the last // invalid item in this block, but let's ignore that optimization for now). - if (AnyBlacklistedOrNonRematerializable(block, rematerializable_map)) { + if (AnyDenylistedOrNonRematerializable(block, rematerializable_map)) { continue; } while (block.size() <= max_block_size) { @@ -1289,8 +1289,8 @@ MemoryUsageTracker::PickRematerializationCandidates( // Time to update the block to include the next instruction. auto* last_item = block[block.size() - 1]; auto* next_item = instruction_list.next(last_item); - if (next_item == nullptr || next_item->blacklisted || - !next_item->placed || next_item == in_progress_item_ || + if (next_item == nullptr || next_item->denylisted || !next_item->placed || + next_item == in_progress_item_ || !CanBeRematerialized(next_item->instruction, rematerializable_map)) { break; } @@ -1404,7 +1404,7 @@ StatusOr RematerializeInstructions( // instruction it was a copying of. Now 'remat' is a rematerialization // of 'best' and kills 'best'. Stop rematerializing this instruction // to avoid an infinite loop. 
- instruction_list->Blacklist(remat); + instruction_list->Denylist(remat); } remat_move_instructions->insert(remat); } else { @@ -1460,8 +1460,8 @@ StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, place_before.push_back(instruction_list->GetItem(user)); } - instruction_list->Blacklist(compressed_item->instruction); - instruction_list->Blacklist(uncompressed_item->instruction); + instruction_list->Denylist(compressed_item->instruction); + instruction_list->Denylist(uncompressed_item->instruction); instruction_list->InsertBeforeInstructions(uncompressed_item, place_before); @@ -1583,7 +1583,7 @@ StatusOr HloRematerialization::RematerializeComputation( // rematerialization is added to 'remat_move_instructions' (the // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the - // blacklist. + // denylist. absl::flat_hash_set remat_move_instructions; // The map from instructions to their rematerializable status. diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 30a7916c408..83130108dd7 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -211,7 +211,8 @@ static std::vector ExecutionInputsFromScopedShapedBuffers( *buffer_tree.mutable_element(index) = execution_input_buffer; } }); - execution_inputs.emplace_back(std::move(buffer_tree)); + execution_inputs.emplace_back(std::move(buffer_tree), + input_buffer.on_host_shape()); } return execution_inputs; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index b0a03707efb..92270005ffd 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,47 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(assignment); } +HloSharding HloSharding::PartialTile( + const Array& group_tile_assignment, + absl::Span> replication_groups) { + auto new_tile_dims = group_tile_assignment.dimensions(); + new_tile_dims.push_back(replication_groups[0].size()); + auto new_tile_assignment = Array(new_tile_dims); + new_tile_assignment.Each([&](absl::Span indices, int64* device) { + std::vector group_index(indices.begin(), indices.end()); + group_index.pop_back(); + int64 group = group_tile_assignment(group_index); + *device = replication_groups[group][indices.back()]; + }); + return PartialTile(new_tile_assignment); +} + +HloSharding HloSharding::PartialTile( + const Array& tile_assignment_last_dim_replicate) { + std::vector> sorted_groups( + tile_assignment_last_dim_replicate.num_elements() / + tile_assignment_last_dim_replicate.dimensions().back()); + auto get_group_id = [&](absl::Span indices) { + int64 group_id = 0; + for (int64 i = 0; i < indices.size() - 1; ++i) { + group_id *= tile_assignment_last_dim_replicate.dim(i); + group_id += indices[i]; + } + return group_id; + }; + tile_assignment_last_dim_replicate.Each( + [&](absl::Span indices, const int64 device) { + sorted_groups[get_group_id(indices)].insert(device); + }); + Array sorted_tile(tile_assignment_last_dim_replicate.dimensions()); + sorted_tile.Each([&](absl::Span indices, int64* device) { + auto begin = sorted_groups[get_group_id(indices)].begin(); + *device = *begin; + sorted_groups[get_group_id(indices)].erase(begin); + }); + return HloSharding(sorted_tile, /*replicate_on_last_tile_dim=*/true); +} + 
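As a rough usage sketch (not part of this patch), the new PartialTile overloads can be driven the same way as the ParseShardingPartialReplication test added later in this change; the device ids and group sizes below are assumptions.

// Illustrative sketch (not part of this change).
#include <vector>
#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/core/platform/logging.h"

xla::HloSharding PartialTileSketch() {
  // Two tile groups along dimension 0; devices {0,1} replicate the first
  // tile and devices {2,3} replicate the second.
  xla::Array<xla::int64> group_tiling({2});
  group_tiling(0) = 0;
  group_tiling(1) = 1;
  std::vector<xla::int64> group0_members({0, 1});
  std::vector<xla::int64> group1_members({2, 3});
  xla::HloSharding sharding = xla::HloSharding::PartialTile(
      group_tiling, {group0_members, group1_members});
  // Prints "{devices=[2,2]0,1,2,3 last_tile_dim_replicate}".
  LOG(INFO) << sharding.ToString();
  CHECK_EQ(sharding.NumTiles(), 2);  // 4 devices with replication factor 2.
  return sharding;
}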
HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { std::vector flattened_list; flattened_list.reserve(sub_shardings.leaf_count()); @@ -101,8 +142,10 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); } - return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", - StrJoin(tile_assignment_, ","), "}"); + return StrCat( + "{devices=[", StrJoin(tile_assignment_.dimensions(), ","), "]", + StrJoin(tile_assignment_, ","), + replicate_on_last_tile_dim_ ? " last_tile_dim_replicate}" : "}"); } bool HloSharding::UsesDevice(int64 device) const { @@ -148,6 +191,9 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { } }); CHECK(!ret_index.empty()); + if (replicate_on_last_tile_dim_) { + ret_index.pop_back(); + } return ret_index; } @@ -157,6 +203,12 @@ int64 HloSharding::DeviceForTileIndex(absl::Span index) const { if (maximal_) { return *tile_assignment_.begin(); } + if (replicate_on_last_tile_dim_ && + index.size() < tile_assignment().num_dimensions()) { + std::vector first_replicated_index(index.begin(), index.end()); + first_replicated_index.push_back(0); + return tile_assignment_(first_replicated_index); + } return tile_assignment_(index); } @@ -167,8 +219,11 @@ std::vector HloSharding::TileOffsetForDevice(const Shape& shape, if (maximal_) { return std::vector(shape.dimensions_size(), 0); } - - CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + if (replicate_on_last_tile_dim_) { + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions() - 1); + } else { + CHECK_EQ(shape.dimensions_size(), tile_assignment_.num_dimensions()); + } std::vector index = TileIndexForDevice(device); for (int64 i = 0; i < index.size(); ++i) { const int64 shape_dim = shape.dimensions(i); @@ -341,8 +396,10 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return Status::OK(); } - // The tile assignment tensor must have the same rank as the input. - if (shape.rank() != tile_assignment_.num_dimensions()) { + // The tile assignment tensor must have the same rank as the input, or input + // rank + 1 for replicate_on_last_tile_dim_. + if (shape.rank() + (replicate_on_last_tile_dim_ ? 1 : 0) != + tile_assignment_.num_dimensions()) { return tensorflow::errors::InvalidArgument( "Number of tile assignment dimensions is different to the input rank. " "sharding=", @@ -403,7 +460,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, proto.tile_assignment_dimensions().end())); std::copy(proto.tile_assignment_devices().begin(), proto.tile_assignment_devices().end(), tile_assignment.begin()); - return HloSharding(tile_assignment); + return proto.replicate_on_last_tile_dim() ? 
PartialTile(tile_assignment) + : HloSharding(tile_assignment); } OpSharding HloSharding::ToProto() const { @@ -429,6 +487,7 @@ OpSharding HloSharding::ToProto() const { result.set_type(OpSharding::MAXIMAL); } else { result.set_type(OpSharding::OTHER); + result.set_replicate_on_last_tile_dim(ReplicateOnLastTileDim()); } return result; } @@ -464,6 +523,17 @@ Shape HloSharding::TileShape(const Shape& shape, int64 device) const { return result_shape; } +int64 HloSharding::NumTiles() const { + if (IsTileMaximal()) { + return 1; + } + if (ReplicateOnLastTileDim()) { + return tile_assignment().num_elements() / + tile_assignment().dimensions().back(); + } + return tile_assignment().num_elements(); +} + HloSharding HloSharding::GetSubSharding(const Shape& shape, const ShapeIndex& index) const { CHECK(IsTuple()); @@ -516,6 +586,9 @@ size_t HloSharding::Hash() const { for (uint32 v : tile_assignment_) { h = tensorflow::Hash64Combine(h, std::hash{}(v)); } + if (replicate_on_last_tile_dim_) { + h = tensorflow::Hash64Combine(h, std::hash{}(1)); + } return h; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 20fa7232e65..e7ba2bc0680 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -54,6 +54,19 @@ class HloSharding { return HloSharding(tile_assignment); } + // Creates a new sharding where data is replicated within each replication + // group, and sharded across replication groups according to + // group_tile_assignment. Replication group members will be sorted. + static HloSharding PartialTile( + const Array& group_tile_assignment, + absl::Span> replication_groups); + + // Creates a partially replicated tiled sharding with device-level tile + // assignment, where the last dimension is the additional replication + // dimension. Replication group members will be sorted. + static HloSharding PartialTile( + const Array& tile_assignment_last_dim_replicate); + // Creates a new sharding which splits a one-dimensional input shape into // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); @@ -115,6 +128,11 @@ class HloSharding { }); } + // Returns if the sharding has partial replication and partial sharding. If + // true, data is sharded according to other dimensions of tile_assignment(), + // but replicated across devices along the last dimension. + bool ReplicateOnLastTileDim() const { return replicate_on_last_tile_dim_; } + // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; @@ -132,6 +150,10 @@ class HloSharding { // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. + // When ReplicateOnLastTileDim() == true, if index.size() == data rank, it + // returns the first device in that replicated subgroup; otherwise, + // index.size() should be the same as tile_assignment()'s rank and specifies + // the member of the replication subgroup. 
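// Editor's illustration (not part of this change): with tile_assignment
// {{0,1},{2,3}} (shape [2,2]) and ReplicateOnLastTileDim() == true, the data
// rank is 1, so DeviceForTileIndex({1}) returns 2, the first device of the
// second replication subgroup, while DeviceForTileIndex({1, 1}) returns 3.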
// REQUIRES: !IsTuple() int64 DeviceForTileIndex(absl::Span index) const; @@ -188,7 +210,8 @@ class HloSharding { bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && tile_assignment_ == other.tile_assignment_ && - tuple_elements_ == other.tuple_elements_; + tuple_elements_ == other.tuple_elements_ && + replicate_on_last_tile_dim_ == other.replicate_on_last_tile_dim_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } @@ -220,12 +243,17 @@ class HloSharding { // REQUIRES: !IsTuple() Shape TileShape(const Shape& shape, int64 device) const; + // Gets the number of tiles. If it has partial replication, this will not + // equal the device count. + int64 NumTiles() const; + private: HloSharding() : replicated_(true), maximal_(true), tuple_(false), - tile_assignment_({0}) {} + tile_assignment_({0}), + replicate_on_last_tile_dim_(false) {} // device_id values: // -2: magic number to mean unassigned device, used by spatial partitioning // -1: the id of the host @@ -236,18 +264,22 @@ class HloSharding { : replicated_(false), maximal_(true), tuple_(false), - tile_assignment_({1}, device_id) {} - explicit HloSharding(const Array& tile_assignment) + tile_assignment_({1}, device_id), + replicate_on_last_tile_dim_(false) {} + explicit HloSharding(const Array& tile_assignment, + bool replicate_on_last_tile_dim = false) : replicated_(false), maximal_(false), tuple_(false), - tile_assignment_(tile_assignment) {} + tile_assignment_(tile_assignment), + replicate_on_last_tile_dim_(replicate_on_last_tile_dim) {} explicit HloSharding(const std::vector& tuple_shardings) : replicated_(false), maximal_(false), tuple_(true), tile_assignment_({0}), - tuple_elements_(tuple_shardings) {} + tuple_elements_(tuple_shardings), + replicate_on_last_tile_dim_(false) {} // Checks that the number of elements in tuple_elements_ is consistent with // the tuple shape passes as argument. @@ -283,6 +315,11 @@ class HloSharding { // present for the root. This is a flattened list of all the leaf shardings in // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector tuple_elements_; + // This flag is to support partial replication and partial sharding. If it is + // true, tile_assignment_ will have an extra dimension in addition to the data + // shape rank, and the added last dimension represents the subgroups of + // replications, i.e., elements in slice [..., :] will be replicated. + bool replicate_on_last_tile_dim_; }; std::ostream& operator<<(std::ostream& out, const HloSharding& sharding); diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index 7fc05608800..65295a8e620 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include +#include #include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/array.h" @@ -23,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -190,13 +192,22 @@ absl::optional ReshapeSharding(const Shape& source_shape, target_dims_stack.push_back(t_size); } else if (s_size > t_size) { // Dimension split. - if (s_size % t_size != 0 || t_size % s_partitions != 0) { + if (s_size % t_size != 0 || s_size % s_partitions != 0) { + return absl::nullopt; + } + if (t_size % s_partitions == 0) { + target_tile_assignment_dimensions.push_back(s_partitions); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(1); + } else if (s_partitions % t_size == 0) { + target_tile_assignment_dimensions.push_back(t_size); + // We have part of the s_size unprocessed, so put it back to stack. + source_dims_stack.push_back(s_size / t_size); + sharding_tile_dims_stack.push_back(s_partitions / t_size); + } else { return absl::nullopt; } - target_tile_assignment_dimensions.push_back(s_partitions); - // We have part of the s_size unprocessed, so put it back to stack. - source_dims_stack.push_back(s_size / t_size); - sharding_tile_dims_stack.push_back(1); } else { // Dimension merge. Also merge the source dimension with the next, and // process it next time. @@ -322,6 +333,10 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding, } } Array new_tile_assignment = index_sharding.tile_assignment(); + if (new_tile_assignment.num_elements() != + Product(output_tile_assignment_dims)) { + return HloSharding::Replicate(); + } new_tile_assignment.Reshape(output_tile_assignment_dims); return HloSharding::Tile(new_tile_assignment); } @@ -341,6 +356,10 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding, } } Array new_tile_assignment = output_sharding.tile_assignment(); + if (new_tile_assignment.num_elements() != + Product(index_tile_assignment_dims)) { + return HloSharding::Replicate(); + } new_tile_assignment.Reshape(index_tile_assignment_dims); return HloSharding::Tile(new_tile_assignment); } @@ -413,6 +432,10 @@ HloSharding ScatterIndexSharding(const HloSharding& data_sharding, index_tile_assignment_dims.push_back(1); } Array new_tile_assignment = data_sharding.tile_assignment(); + if (new_tile_assignment.num_elements() != + Product(index_tile_assignment_dims)) { + return HloSharding::Replicate(); + } new_tile_assignment.Reshape(index_tile_assignment_dims); return HloSharding::Tile(new_tile_assignment); } @@ -435,6 +458,10 @@ HloSharding ScatterDataSharding(const HloSharding& index_sharding, } } Array new_tile_assignment = index_sharding.tile_assignment(); + if (new_tile_assignment.num_elements() != + Product(data_tile_assignment_dims)) { + return HloSharding::Replicate(); + } new_tile_assignment.Reshape(data_tile_assignment_dims); return HloSharding::Tile(new_tile_assignment); } @@ -524,6 +551,169 @@ HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, return HloSharding::Tile(tile_assignment); } +namespace { + +// If partitioning in the operand only happens in dimensions in passthrough +// dimensions (offset dimensions in the gather output (or scatter update) that +// have the same size as the operand), returns the corresponding output (or +// update) sharding by passing through the input sharding. 
+absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( + const Shape& operand_shape, const HloSharding& operand_sharding, + const Shape& update_or_gather_shape, + absl::Span collapsed_or_inserted_dims, + absl::Span index_map, + absl::Span offset_or_window_dims, + absl::Span slice_size) { + if (operand_sharding.IsTileMaximal()) { + return operand_sharding; + } + std::vector passthrough_tile(update_or_gather_shape.rank(), 1); + int64 collapsed = 0; + for (int64 i = 0; i < operand_shape.rank(); ++i) { + int64 dim_partitions = operand_sharding.tile_assignment().dim(i); + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(index_map, i)) { + if (dim_partitions > 1) { + return absl::nullopt; + } + collapsed++; + continue; + } + if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) { + return absl::nullopt; + } + int64 offset_dim = offset_or_window_dims[i - collapsed]; + if (i - collapsed > 0 && + offset_dim < offset_or_window_dims[i - collapsed - 1]) { + // Output offsets are transposed, we do not support this case. + return absl::nullopt; + } + passthrough_tile[offset_dim] = dim_partitions; + } + Array tile_assignment = operand_sharding.tile_assignment(); + tile_assignment.Reshape(passthrough_tile); + return HloSharding::Tile(tile_assignment); +} + +// Inverse of PassthroughOperandToGatherOutputOrScatterUpdate. +absl::optional PassthroughGatherOutputOrScatterUpdateToOperand( + const Shape& operand_shape, const HloSharding& update_or_gather_sharding, + absl::Span collapsed_or_inserted_dims, + absl::Span index_map, + absl::Span offset_or_window_dims, + absl::Span slice_size) { + if (update_or_gather_sharding.IsTileMaximal()) { + return update_or_gather_sharding; + } + std::vector passthrough_tile(operand_shape.rank(), 1); + int64 collapsed = 0; + for (int64 i = 0; i < operand_shape.rank(); ++i) { + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(index_map, i)) { + collapsed++; + continue; + } + int64 offset_dim = offset_or_window_dims[i - collapsed]; + int64 dim_partitions = + update_or_gather_sharding.tile_assignment().dim(offset_dim); + if (slice_size[i] != operand_shape.dimensions(i) && dim_partitions > 1) { + return absl::nullopt; + } + if (i - collapsed > 0 && + offset_dim < offset_or_window_dims[i - collapsed - 1]) { + // Output offsets are transposed, we do not support this case. 
+ return absl::nullopt; + } + passthrough_tile[i] = dim_partitions; + } + Array tile_assignment = update_or_gather_sharding.tile_assignment(); + if (tile_assignment.num_elements() != Product(passthrough_tile)) { + return absl::nullopt; + } + tile_assignment.Reshape(passthrough_tile); + return HloSharding::Tile(tile_assignment); +} + +} // namespace + +absl::optional GatherOutputShardingFromDataOperand( + const HloSharding& data_operand_sharding, const HloInstruction& hlo) { + const auto& dnums = hlo.gather_dimension_numbers(); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + return PassthroughOperandToGatherOutputOrScatterUpdate( + hlo.operand(0)->shape(), data_operand_sharding, hlo.shape(), + collapsed_slice_dims, start_index_map, offset_dims, + hlo.gather_slice_sizes()); +} + +absl::optional GatherDataOperandShardingFromOutput( + const HloSharding& output_sharding, const HloInstruction& hlo) { + const auto& dnums = hlo.gather_dimension_numbers(); + std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), + dnums.collapsed_slice_dims().end()); + std::vector start_index_map(dnums.start_index_map().begin(), + dnums.start_index_map().end()); + std::vector offset_dims(dnums.offset_dims().begin(), + dnums.offset_dims().end()); + return PassthroughGatherOutputOrScatterUpdateToOperand( + hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims, + start_index_map, offset_dims, hlo.gather_slice_sizes()); +} + +absl::optional ScatterOutputShardingFromUpdate( + const HloSharding& update_sharding, const HloInstruction& hlo) { + const auto& dnums = hlo.scatter_dimension_numbers(); + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector slice_size(hlo.shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = hlo.operand(2)->shape().dimensions( + dnums.update_window_dims(num_update_window_dims++)); + } + return PassthroughGatherOutputOrScatterUpdateToOperand( + hlo.shape(), update_sharding, inserted_window_dims, + scatter_dims_to_operand_dims, update_window_dims, slice_size); +} + +absl::optional ScatterUpdateShardingFromOutput( + const HloSharding& output_sharding, const HloInstruction& hlo) { + const auto& dnums = hlo.scatter_dimension_numbers(); + std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), + dnums.inserted_window_dims().end()); + std::vector scatter_dims_to_operand_dims( + dnums.scatter_dims_to_operand_dims().begin(), + dnums.scatter_dims_to_operand_dims().end()); + std::vector update_window_dims(dnums.update_window_dims().begin(), + dnums.update_window_dims().end()); + std::vector slice_size(hlo.shape().rank(), 1); + int64 num_update_window_dims = 0; + for (int64 i = 0; i < hlo.shape().rank(); ++i) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + continue; + } + slice_size[i] = hlo.operand(2)->shape().dimensions( + 
dnums.update_window_dims(num_update_window_dims++)); + } + return PassthroughOperandToGatherOutputOrScatterUpdate( + hlo.shape(), output_sharding, hlo.operand(2)->shape(), + inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims, + slice_size); +} + StatusOr, HloOpcode>> IdentityValueAndHloOpcodeForScatterReduceComputation( const HloScatterInstruction& scatter) { @@ -588,5 +778,68 @@ std::vector DevicesForSharding( return devices; } +HloSharding PartiallyReplicateTiledShardingOnDims( + const HloSharding& sharding, const std::vector& dims_to_replicate) { + if (sharding.IsTileMaximal()) { + return sharding; + } + int64 group_count = 1; + for (int64 dim : dims_to_replicate) { + if (sharding.ReplicateOnLastTileDim()) { + CHECK_LT(dim, sharding.tile_assignment().num_dimensions()); + } + group_count *= sharding.tile_assignment().dim(dim); + } + if (group_count == 1) { + return sharding; + } + if (group_count == sharding.NumTiles()) { + return HloSharding::Replicate(); + } + std::vector dim_permutation( + sharding.tile_assignment().num_dimensions()); + std::iota(dim_permutation.begin(), dim_permutation.end(), 0); + absl::c_sort(dim_permutation, [&](const int64 a, const int64 b) { + return absl::c_linear_search(dims_to_replicate, a) < + absl::c_linear_search(dims_to_replicate, b); + }); + auto transposed = TransposeSharding(sharding, dim_permutation); + auto new_tile = transposed.tile_assignment(); + std::vector new_tile_shape( + sharding.tile_assignment().dimensions().begin(), + sharding.tile_assignment().dimensions().end()); + for (int64 dim : dims_to_replicate) { + new_tile_shape[dim] = 1; + } + if (sharding.ReplicateOnLastTileDim()) { + new_tile_shape.back() *= group_count; + } else { + new_tile_shape.push_back(group_count); + } + new_tile.Reshape(new_tile_shape); + return HloSharding::PartialTile(new_tile); +} + +HloSharding RemoveShapeDimensions(const HloSharding& sharding, + const std::vector& dims_to_remove) { + if (sharding.IsTileMaximal() || dims_to_remove.empty()) { + return sharding; + } + std::vector new_tile_shape; + new_tile_shape.reserve(sharding.tile_assignment().num_dimensions() - + dims_to_remove.size()); + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (absl::c_linear_search(dims_to_remove, i)) { + CHECK_EQ(sharding.tile_assignment().dim(i), 1); + } else { + new_tile_shape.push_back(sharding.tile_assignment().dim(i)); + } + } + auto new_tile = sharding.tile_assignment(); + new_tile.Reshape(new_tile_shape); + return sharding.ReplicateOnLastTileDim() ? HloSharding::PartialTile(new_tile) + : HloSharding::Tile(new_tile); +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.h b/tensorflow/compiler/xla/service/hlo_sharding_util.h index 562f6d1420d..ce19d8c7a19 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.h @@ -127,6 +127,26 @@ HloSharding ScatterEffectiveIndexSharding(const HloSharding& index_sharding, HloSharding ScatterEffectiveDataSharding(const HloSharding& data_sharding, const HloInstruction& hlo); +// Returns an output sharding of gather by passing through the data operand's +// sharding. +absl::optional GatherOutputShardingFromDataOperand( + const HloSharding& data_operand_sharding, const HloInstruction& hlo); + +// Returns a data operand sharding of gather by passing through the output's +// sharding. 
+absl::optional GatherDataOperandShardingFromOutput( + const HloSharding& output_sharding, const HloInstruction& hlo); + +// Returns an output sharding of scatter by passing through the update operand's +// sharding. +absl::optional ScatterOutputShardingFromUpdate( + const HloSharding& update_sharding, const HloInstruction& hlo); + +// Returns an update operand sharding of scatter by passing through the output's +// sharding. +absl::optional ScatterUpdateShardingFromOutput( + const HloSharding& output_sharding, const HloInstruction& hlo); + // Returns an identity value and an HloOpcode for reduce computation of scatter // instruction. // - If computation is add/or, return 0/false with corresponding op code; @@ -143,6 +163,17 @@ IdentityValueAndHloOpcodeForScatterReduceComputation( std::vector DevicesForSharding( const HloSharding& sharding, const std::vector& available_devices); +// Returns a sharding that replicates data across devices along the given +// dimensions in the original sharding. +HloSharding PartiallyReplicateTiledShardingOnDims( + const HloSharding& sharding, const std::vector& dims_to_replicate); + +// Returns a sharding that removes the given tile dimensions. +// +// Precondition: if not tile maximal, the size of each tile dimension must be 1. +HloSharding RemoveShapeDimensions(const HloSharding& sharding, + const std::vector& dims_to_remove); + } // namespace hlo_sharding_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc index 02496c75965..08f136b2e45 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util_test.cc @@ -76,6 +76,20 @@ TEST(HloShardingUtilTest, ReshapeShardingTiledSplit) { EXPECT_EQ(result.value(), output_sharding); } +TEST(HloShardingUtilTest, ReshapeShardingTiledSplit2) { + Shape input_shape = ShapeUtil::MakeShape(F32, {16, 7}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 4, 7}); + Array2D tile(16, 1); + tile.FillIota(0); + HloSharding input_sharding = HloSharding::Tile(tile); + tile.Reshape({4, 4, 1}); + HloSharding output_sharding = HloSharding::Tile(tile); + absl::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + TEST(HloShardingUtilTest, ReshapeShardingTiledSplitThenMerge) { Shape input_shape = ShapeUtil::MakeShape(F32, {16, 4, 7}); Shape output_shape = ShapeUtil::MakeShape(F32, {4, 16, 7}); diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index a1150ae299d..a721aabef76 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -57,6 +57,11 @@ struct HloPosition { (instruction->unique_id() == other.instruction->unique_id() && index < other.index); } + + template + friend H AbslHashValue(H h, const HloPosition& pos) { + return H::combine(std::move(h), pos.instruction->Hash(), pos.index); + } }; std::ostream& operator<<(std::ostream& out, const HloPosition& position); @@ -81,6 +86,12 @@ struct HloUse { } bool operator!=(const HloUse& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const HloUse& use) { + return H::combine(std::move(h), use.instruction, use.operand_index, + use.operand_number); + } }; std::ostream& operator<<(std::ostream& out, const HloUse& use); @@ -240,7 +251,8 @@ std::ostream& operator<<(std::ostream&
out, const HloValueSet& hlo_value); // hold multiple HloValueSets. class InstructionValueSet : public ShapeTree { public: - InstructionValueSet(const Shape& shape) : ShapeTree(shape) {} + explicit InstructionValueSet(const Shape& shape) + : ShapeTree(shape) {} // Sets this value set to the union of the given value sets. Returns whether // this value set changed. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 62b0d98418c..d395fddcc5d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -670,14 +670,6 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { - // Bitcasts are not allowed to change the element type. - if (bitcast->operand(0)->shape().element_type() != - bitcast->shape().element_type()) { - return InternalError( - "Bitcast can not change the element type from %s to %s", - PrimitiveType_Name(bitcast->operand(0)->shape().element_type()), - PrimitiveType_Name(bitcast->shape().element_type())); - } if (layout_sensitive_ && shape_size_function_(bitcast->shape()) != shape_size_function_(bitcast->operand(0)->shape())) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d9709c50df9..1f71c9586d5 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -540,24 +540,6 @@ TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { HasSubstr("Instruction shouldn't change layouts")); } -TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) { - const char* const hlo_string = R"( - HloModule Module - - ENTRY BitcastCanNotChangeElementType { - constant.0 = f32[2] constant({0.0, 0.0}) - ROOT bitcast = s32[2] bitcast(constant.0) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnUnverifiedModule(hlo_string)); - - auto status = verifier().Run(module.get()).status(); - ASSERT_FALSE(status.ok()); - EXPECT_THAT(status.error_message(), - HasSubstr("Bitcast can not change the element type")); -} - TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) { const char* const hlo_string = R"( HloModule Module diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 7a4eefc1ab6..3444d4cae42 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -34,6 +34,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:cholesky_expander", + "//tensorflow/compiler/xla/service:comparison_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:custom_call_target_registry", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 1649be2ca8f..a059482d832 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/cholesky_expander.h" +#include "tensorflow/compiler/xla/service/comparison_expander.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" @@ -81,6 +82,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc index 4b020ea2d32..4b6a8aa5202 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -81,8 +81,17 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( for (int64 i = 0; i < computation->num_parameters(); ++i) { const auto& expected_shape = computation->parameter_instruction(i)->shape(); const auto& actual_shape = argument_buffers[i].on_device_shape(); - if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, - actual_shape)) { + bool shape_match = true; + if (expected_shape.is_dynamic()) { + if (!ShapeUtil::DynamicArrayShapeIsCompatible(actual_shape, + expected_shape)) { + shape_match = false; + } + } else if (!Shape::Equal().MinorToMajorOnlyInLayout()(expected_shape, + actual_shape)) { + shape_match = false; + } + if (!shape_match) { return InvalidArgument( "Shape mismatch on parameter %d. Expected %s, but was %s.", i, ShapeUtil::HumanStringWithLayout(expected_shape), @@ -100,11 +109,18 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( TF_ASSIGN_OR_RETURN(Literal arg_literal, transfer_manager->TransferLiteralFromDevice( run_options->stream(), argument_buffers[p])); + const auto& expected_shape = computation->parameter_instruction(p)->shape(); + if (expected_shape.is_dynamic()) { + // Expand the input literal to expected shape. + arg_literal = arg_literal.ToBoundedDynamic(expected_shape); + } arg_literals.push_back(std::move(arg_literal)); } TF_ASSIGN_OR_RETURN(Literal result_literal, Evaluate(*computation, arg_literals)); + // Shrink the generated dynamic shape into static shape. + result_literal = result_literal.ToStatic(); // Transform the result literal back into a ShapedBuffer. TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result_buffers, diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 3c48668e742..bea0f1fb93c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1357,6 +1357,20 @@ Status LayoutAssignment::PropagateOperandConstraint( // Propagate layouts between operands of the same instruction. This is a // constraint on non-layout-changing instructions. if (!instruction_can_change_layout_func_(user)) { + // Only propgate the layout of the largest concatenate operand. 
+ if (user->opcode() == HloOpcode::kConcatenate) { + for (int64 operand_no = 0; operand_no < user->operand_count(); + ++operand_no) { + const HloInstruction* sibling = user->operand(operand_no); + if (sibling == operand) { + continue; + } + if (sibling->shape().dimensions(user->concatenate_dimension()) > + operand->shape().dimensions(user->concatenate_dimension())) { + return Status::OK(); + } + } + } // Make sure all siblings have the same layout as the operand. for (int64 operand_no = 0; operand_no < user->operand_count(); ++operand_no) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index f96c985da71..33121635b0b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -54,9 +54,7 @@ string SanitizeConstantName(const HloInstruction& instr) { return instr_name; } -string ConstantBufferAllocationToGlobalName( - const BufferAllocation& allocation) { - const HloInstruction& instr = InstrForConstantBufferAllocation(allocation); +string ConstantHloToGlobalName(const HloInstruction& instr) { string instr_name = instr.name(); // Check that names are sanitized and stored in the HLO instructions // before constant buffer allocation. @@ -64,6 +62,11 @@ string ConstantBufferAllocationToGlobalName( return absl::StrCat("buffer_for_", instr_name); } +string ConstantBufferAllocationToGlobalName( + const BufferAllocation& allocation) { + return ConstantHloToGlobalName(InstrForConstantBufferAllocation(allocation)); +} + const Literal& LiteralForConstantAllocation( const BufferAllocation& allocation) { return InstrForConstantBufferAllocation(allocation).literal(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h index 03e98a66900..2e2d3bf0b48 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h @@ -24,6 +24,9 @@ namespace llvm_ir { // name of the corresponding constant buffer. In particular, it replaces . and // - with _. string SanitizeConstantName(const HloInstruction& instr); + +string ConstantHloToGlobalName(const HloInstruction& instr); + // In XLA:GPU we map constant buffer allocations to globals in the generated // LLVM IR. This function gives us the name of the global variable a constant // buffer is mapped to. Not used on XLA:CPU. 
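The buffer_assignment_util change above factors the name derivation out of ConstantBufferAllocationToGlobalName so that a caller holding only the constant HloInstruction can obtain the global's name without first resolving a BufferAllocation. A minimal usage sketch, not part of this patch: the EmitGlobalForConstantHlo helper and its linkage choice are hypothetical, and it assumes the existing llvm_ir::ConvertLiteralToIrConstant utility.

    #include "llvm/IR/GlobalVariable.h"
    #include "llvm/IR/Module.h"
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
    #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"

    // Hypothetical helper: emit an LLVM global for a constant HLO before any
    // buffer allocation has been assigned to it.
    llvm::GlobalVariable* EmitGlobalForConstantHlo(
        const xla::HloInstruction& constant, llvm::Module* module) {
      llvm::Constant* initializer =
          xla::llvm_ir::ConvertLiteralToIrConstant(constant.literal(), module);
      // The global uses the same "buffer_for_..." name that
      // ConstantBufferAllocationToGlobalName would later produce for the
      // allocation backing this constant, since the latter now forwards to
      // ConstantHloToGlobalName.
      return new llvm::GlobalVariable(
          *module, initializer->getType(), /*isConstant=*/true,
          llvm::GlobalValue::PrivateLinkage, initializer,
          xla::llvm_ir::ConstantHloToGlobalName(constant));
    }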
diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index daf98478194..d89a9c2e0a5 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -62,10 +62,11 @@ void EmitTuple(const IrArray& tuple, absl::Span<llvm::Value* const> operands, llvm::IRBuilder<>* b) { llvm::Module* module = getModuleFromBuilder(b); for (size_t i = 0; i < operands.size(); ++i) { + auto* cast = + b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)); auto* store = b->CreateStore( - b->CreatePointerCast(operands[i], PrimitiveTypeToIrType(TUPLE, module)), - b->CreateInBoundsGEP(tuple.GetBasePointer(), - {b->getInt64(0), b->getInt64(i)})); + cast, b->CreateInBoundsGEP(tuple.GetBasePointer(), + {b->getInt64(0), b->getInt64(i)})); tuple.AnnotateLoadStoreInstructionWithMetadata(store); } } diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index c80646e0c70..5def5bbe9db 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -114,6 +114,7 @@ ExecutionOptions CreateExecutionOptions( execution_options.set_num_partitions(build_options.num_partitions()); execution_options.set_use_spmd_partitioning( build_options.use_spmd_partitioning()); + execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo()); if (build_options.has_device_assignment()) { TF_CHECK_OK(build_options.device_assignment().Serialize( execution_options.mutable_device_assignment())); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 4b26fba3bab..c5ae0573bed 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -199,6 +199,12 @@ float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory( } float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed( + const HloInstruction& instruction) const { + return std::max(GetInstructionElapsedDueToCompute(instruction), + GetInstructionElapsedDueToMemory(instruction)); +} + +float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory( const HloInstruction& instruction, absl::optional<int64> operand_in_alternate_mem, bool output_in_alternate_mem) const { @@ -229,6 +235,11 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( return std::min(start_time + min_overlap_count_, latest_end_time); } +int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( + const HloUse& use, int64 start_time, int64 end_time) const { + return end_time - min_overlap_count_; +} + void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { @@ -258,12 +269,15 @@ std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString( CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, - float max_async_copy_to_overlap_ratio) + float max_async_copy_to_overlap_ratio, + float preferred_async_copy_to_overlap_ratio) : while_nest_level_( cost_analysis.hlo_live_range().instruction_schedule().size(), 0), cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), -
max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio), + preferred_async_copy_to_overlap_ratio_( + preferred_async_copy_to_overlap_ratio) { instruction_schedule_ = &cost_analysis_.hlo_live_range().instruction_schedule(); @@ -277,12 +291,6 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( // To avoid double counting, don't include the elapsed time of while and // conditional HLOs. const HloInstruction* instruction = instruction_and_logical_time.first; - if (instruction->opcode() == HloOpcode::kWhile || - instruction->opcode() == HloOpcode::kConditional) { - continue; - } - float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( - *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; if (logical_time >= instructions_elapsed_time.size()) { instructions_elapsed_time.resize(logical_time + 1, 0.0); @@ -291,6 +299,12 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( int nest_level = cost_analysis_.CalculateWhileLoopNestLevel( instruction_and_logical_time.first); while_nest_level_[logical_time] = nest_level; + if (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float elapsed_time = cost_analysis_.GetInstructionElapsed( + *instruction_and_logical_time.first); instructions_elapsed_time[logical_time] = elapsed_time * tensorflow::MathUtil::IPow(kWhileExecutionCount, nest_level); @@ -346,6 +360,49 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( return end_time; } +int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( + const HloUse& use, int64 start_time, int64 end_time) const { + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); + // Estimate the time we would save by having this op in alternate memory. + float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use.instruction, use.operand_number, + /*output_in_alternate_mem=*/false); + float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + int end_nest_level = while_nest_level_[end_time]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed; + int latest_prefetch_time; + for (latest_prefetch_time = end_time - 1; + latest_prefetch_time >= start_time && + (while_nest_level_[latest_prefetch_time] != end_nest_level || + min_interval > + GetLogicalIntervalElapsed(latest_prefetch_time, end_time) + + inst_elapsed_reduction); + --latest_prefetch_time) { + } + + return latest_prefetch_time; +} + +int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( + int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const { + // Iterate towards the beginning until we find a suitable end time that is the + // same while nest level as the original prefetch end time. 
+ int64 original_nest_level = while_nest_level_[original_prefetch_end_time]; + int64 new_prefetch_end_time; + for (new_prefetch_end_time = proposed_prefetch_end_time; + while_nest_level_[new_prefetch_end_time] != original_nest_level; + --new_prefetch_end_time) { + } + return new_prefetch_end_time; +} + void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { @@ -355,52 +412,100 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape); // Estimate the time we would save by having this op in alternate memory. float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); - float elapsed_time_in_alternate_mem = cost_analysis_.GetInstructionElapsed( - *use.instruction, use.operand_number); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use.instruction, use.operand_number, + /*output_in_alternate_mem=*/false); inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; end_logical_time_ = end_time; - earliest_start_logical_time_ = start_time; - int end_nest_level = while_nest_level_[end_time]; - // Find the latest time we're allowed to start prefetching. If the start and - // end nest levels differe look for an earlier prefetch start. - for (current_logical_prefetch_time_ = end_time - 1; - current_logical_prefetch_time_ > start_time && - (while_nest_level_[current_logical_prefetch_time_] != end_nest_level || - min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > - GetLogicalIntervalElapsed(current_logical_prefetch_time_, - end_logical_time_) + - inst_elapsed_reduction_); - --current_logical_prefetch_time_) { + int end_nest_level = while_nest_level_[end_logical_time_]; + + // Find the latest time we're allowed to start prefetching. + float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; + latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time); + + // Find the earliest time we're allowed to start prefetching. + float max_interval = max_async_copy_to_overlap_ratio_ * + max_overlap_multiplier_ * async_copy_elapsed_; + for (earliest_prefetch_time_ = start_time; + earliest_prefetch_time_ <= end_logical_time_ && + (while_nest_level_[earliest_prefetch_time_] != end_nest_level || + max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_, + end_logical_time_)); + ++earliest_prefetch_time_) { } + if (earliest_prefetch_time_ > latest_prefetch_time_) { + // There is no available prefetch interval for the given start and end + // times. Set the iterators accordingly to ensure Done() returns true. + increasing_prefetch_time_iterator_ = earliest_prefetch_time_; + decreasing_prefetch_time_iterator_ = latest_prefetch_time_; + CHECK(Done()); + return; + } + + // Between the earliest and latest prefetch interval, find the interval + // closest to the preferred interval and start iterating from there. 
+ int64 starting_prefetch_time = earliest_prefetch_time_; + float preferred_interval = + preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_; + float best_interval = + GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_); + for (int64 prefetch_time = earliest_prefetch_time_ + 1; + prefetch_time <= latest_prefetch_time_; ++prefetch_time) { + float interval = + GetLogicalIntervalElapsed(prefetch_time, end_logical_time_); + if (while_nest_level_[prefetch_time] == end_nest_level && + std::abs(preferred_interval - interval) < + std::abs(preferred_interval - best_interval)) { + best_interval = interval; + starting_prefetch_time = prefetch_time; + } + } + VLOG(4) << "Interval min/max/preferred = " << min_interval << " " + << max_interval << " " << preferred_interval + << " prefetch time earliest/latest/starting = " + << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " " + << starting_prefetch_time; + + increasing_prefetch_time_iterator_ = starting_prefetch_time; + decreasing_prefetch_time_iterator_ = starting_prefetch_time; + using_increasing_prefetch_time_iterator_ = true; + // Since both iterators start at the same position, call Next() once to + // advance one of the iterators. + Next(); } int64 CostAnalysisPrefetchIntervalPicker::Next() { CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " "Done() is false"; - int64 prefetch_time = current_logical_prefetch_time_; - if (!Done()) { - --current_logical_prefetch_time_; + if (using_increasing_prefetch_time_iterator_) { + int64 prefetch_time = increasing_prefetch_time_iterator_++; + while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ && + while_nest_level_[increasing_prefetch_time_iterator_] != + while_nest_level_[end_logical_time_]) { + ++increasing_prefetch_time_iterator_; + } + if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = false; + } + return prefetch_time; + } else { + int64 prefetch_time = decreasing_prefetch_time_iterator_--; + while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ && + while_nest_level_[decreasing_prefetch_time_iterator_] != + while_nest_level_[end_logical_time_]) { + --decreasing_prefetch_time_iterator_; + } + if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) { + using_increasing_prefetch_time_iterator_ = true; + } + return prefetch_time; } - // If the prefetch start and end times differ, look for an earlier prefetch - // start. 
- while (!Done() && while_nest_level_[current_logical_prefetch_time_] != - while_nest_level_[end_logical_time_]) { - --current_logical_prefetch_time_; - } - return prefetch_time; } bool CostAnalysisPrefetchIntervalPicker::Done() const { - if (current_logical_prefetch_time_ < earliest_start_logical_time_) { - return true; - } - float logical_interval_elapsed = GetLogicalIntervalElapsed( - current_logical_prefetch_time_, end_logical_time_); - return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * - async_copy_elapsed_ < - logical_interval_elapsed) || - (min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > - logical_interval_elapsed + inst_elapsed_reduction_); + return increasing_prefetch_time_iterator_ > latest_prefetch_time_ && + decreasing_prefetch_time_iterator_ < earliest_prefetch_time_; } void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { @@ -440,13 +545,16 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { + int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_ + ? increasing_prefetch_time_iterator_ + : decreasing_prefetch_time_iterator_; float logical_interval_elapsed = GetLogicalIntervalElapsed( - current_logical_prefetch_time_, end_logical_time_); + current_logical_prefetch_time, end_logical_time_); return absl::StrCat( "Async copy elapsed (s) = ", async_copy_elapsed_, ", inst elapsed reduction (s) = ", inst_elapsed_reduction_, ", logical interval elapsed (s) = ", logical_interval_elapsed, - ", interval = (", current_logical_prefetch_time_, ", ", end_logical_time_, + ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_, ")"); } @@ -466,6 +574,24 @@ CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( return cost_analysis_.GetMemoryBoundedness(interval); } +bool MemorySpaceAssignment::Allocation::operator==( + const MemorySpaceAssignment::Allocation& other) const { + return defining_position() == other.defining_position() && + uses() == other.uses() && memory_space() == other.memory_space() && + chunk() == other.chunk() && start_time() == other.start_time() && + end_time() == other.end_time() && + is_copy_allocation() == other.is_copy_allocation(); +} + +bool MemorySpaceAssignment::CopyAllocation::operator==( + const MemorySpaceAssignment::CopyAllocation& other) const { + return static_cast(*this) == + static_cast(other) && + copy_done_schedule_before() == other.copy_done_schedule_before() && + copy_start_schedule_after() == other.copy_start_schedule_after() && + copy_start() == other.copy_start() && copy_done() == other.copy_done(); +} + std::string MemorySpaceAssignment::AllocationValue::ToString() const { std::string out = absl::StrCat("computation = ", computation()->name()); absl::StrAppend(&out, "\n position:\n"); @@ -484,7 +610,9 @@ std::string MemorySpaceAssignment::AllocationValue::ToShortString() const { } void AlternateMemoryBestFitHeap::CreateAllocationValues( - const HloValue* value, std::vector* allocation_values) { + const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval, + std::vector& allocation_values) const { + const HloValue* value = buffer_interval.buffer; VLOG(3) << "Creating AllocationValues for: " << value->ToString(); // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast) @@ -512,10 +640,10 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( // Create an AllocationValue for each non-trivial position. 
absl::flat_hash_set computations; - int beginning_idx = allocation_values->size(); + int beginning_idx = allocation_values.size(); for (int i = 0; i < positions.size(); ++i) { const HloPosition& position = positions.at(i); - allocation_values->emplace_back(value, position); + allocation_values.emplace_back(value, position, buffer_interval.size); } std::vector uses(value->uses()); @@ -536,8 +664,8 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( HloComputation* use_computation = use.instruction->parent(); AllocationValue* last_allocation_value = nullptr; - for (int i = beginning_idx; i < allocation_values->size(); ++i) { - AllocationValue* allocation_value = &allocation_values->at(i); + for (int i = beginning_idx; i < allocation_values.size(); ++i) { + AllocationValue* allocation_value = &allocation_values.at(i); if (allocation_value->computation() == use_computation && instruction_schedule.at( allocation_value->defining_position().instruction) < use_time) { @@ -548,9 +676,9 @@ void AlternateMemoryBestFitHeap::CreateAllocationValues( last_allocation_value->AddUse(use, use_time); } - for (int i = beginning_idx; i < allocation_values->size(); ++i) { + for (int i = beginning_idx; i < allocation_values.size(); ++i) { VLOG(3) << "Created allocation value: " - << allocation_values->at(i).ToString(); + << allocation_values.at(i).ToString(); } } @@ -774,7 +902,7 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( std::vector use_names; use_times.reserve(uses.size()); use_names.reserve(uses.size()); - for (auto use : uses) { + for (const auto& use : uses) { use_times.push_back(use.first); use_names.push_back(use.second); } @@ -794,27 +922,27 @@ void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( } void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const AllocationValue& value, const MemorySpaceAssignment::Allocation& allocation, - std::string* debug_str) const { + std::string& debug_str) const { // Columns in allocation information: // buffer_id: int. This value can be used the match with buffer info. // size: int. In bytes. // offset: int. In bytes. // start_time: int. Logical start time of the allocation. // end_time: int. Logical end time of the allocation. - if (debug_str->empty()) { + if (debug_str.empty()) { // Append the column names. 
- absl::StrAppend(debug_str, "buffer_id,size,offset,start_time,end_time\n"); + absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n"); } if (allocation.memory_space() == MemorySpace::kAlternate) { const HloBuffer& buffer = - alias_analysis_.GetBufferContainingValue(*interval.buffer); - absl::StrAppend(debug_str, buffer.id(), ","); - absl::StrAppend(debug_str, interval.size, ","); - absl::StrAppend(debug_str, allocation.chunk().offset, ","); - absl::StrAppend(debug_str, allocation.start_time(), ","); - absl::StrAppend(debug_str, allocation.end_time(), "\n"); + alias_analysis_.GetBufferContainingValue(*value.value()); + absl::StrAppend(&debug_str, buffer.id(), ","); + absl::StrAppend(&debug_str, value.size(), ","); + absl::StrAppend(&debug_str, allocation.chunk().offset, ","); + absl::StrAppend(&debug_str, allocation.start_time(), ","); + absl::StrAppend(&debug_str, allocation.end_time(), "\n"); } } @@ -845,6 +973,16 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { } } + for (const auto& interval : sorted_buffer_intervals) { + auto colocated_intervals = GetSortedColocatedIntervals(interval); + if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { + // Increment the reserved part of alternate memory so that it is not + // available for other buffers. + reserved_in_bytes_ += options_.size_fn(*interval.buffer); + } + } + VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_; + for (auto& interval : sorted_buffer_intervals) { if (!interval.need_allocation) { continue; @@ -872,8 +1010,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) { VLOG(3) << "Interval " << interval.buffer->ToShortString() - << " is reserved in the alternate memory. Total reserved bytes = " - << reserved_in_bytes_; + << " is reserved in the alternate memory."; for (const BufferInterval* colocated_interval : colocated_intervals) { const HloValue* value = colocated_interval->buffer; // Color all of the aliased reserved buffers here because reserved @@ -889,10 +1026,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { options_.alternate_memory_space); } } - // Increment the reserved part of alternate memory so that it is not - // available for other buffers. Since all colocated intervals should have - // the same size, just use the first one. - reserved_in_bytes_ += options_.size_fn(*colocated_intervals[0]->buffer); continue; } @@ -913,16 +1046,43 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { AppendBufferInfoDebugString(interval, &buffer_info_str_); + std::vector allocation_values; + CreateAllocationValuesFromColocatedIntervals(colocated_intervals, + allocation_values); + // Retry allocating this value with larger limits if allocation fails. for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { - final_retry_ = (retry_number == options_.max_retries - 1); + bool final_retry = (retry_number == options_.max_retries - 1); options_.prefetch_interval_picker->SetRetryNumber(retry_number); - bool success = AllocateColocatedIntervals(colocated_intervals); - if (success) { + Result result = + AllocateAllocationValues(absl::MakeSpan(allocation_values)); + VLOG(2) << "Allocation result = " + << absl::StrFormat("%x", static_cast(result)); + if (result_requires_uncommit(result) || + (!final_retry && result_failed_because_of_async_copy(result))) { + UncommitPendingChunks(absl::MakeSpan(allocation_values)); + VLOG(2) << "Couldn't allocate. 
Retry number " << retry_number; + } else if (result_is(result, Result::kFailOutOfMemory) && + num_repacks_ < options_.max_repacks) { + UncommitPendingChunks(absl::MakeSpan(allocation_values)); + ++num_repacks_; + CHECK_NE(options_.repacker, nullptr); + std::vector repack_allocation_blocks; + ExportAllocationsForRepacking(repack_allocation_blocks); + VLOG(2) << "Repacking."; + auto repack_status = + options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks)); + CHECK_EQ(repack_status.status(), Status::OK()); + VLOG(2) << "Repack complete. Modified = " << *repack_status; + if (*repack_status) { + ImportRepackedAllocations(absl::MakeSpan(repack_allocation_blocks)); + --retry_number; + } + } else { + FinalizeAllocations(absl::MakeSpan(allocation_values)); break; } - VLOG(2) << "Couldn't allocate. Retry number " << retry_number; } } @@ -935,9 +1095,10 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { return result_; } -bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( - const std::vector& - colocated_intervals) { +void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals( + absl::Span + colocated_intervals, + std::vector& allocation_values) { // TODO(berkin): For now, place the phi values due to conditionals in // default memory. for (const BufferInterval* colocated_interval : colocated_intervals) { @@ -958,11 +1119,15 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( } // Create AllocationValues for all the colocated intervals. - std::vector allocation_values; for (const auto& colocated_interval : colocated_intervals) { - CreateAllocationValues(colocated_interval->buffer, &allocation_values); + CreateAllocationValues(*colocated_interval, allocation_values); } FindAliases(&allocation_values); +} + +AlternateMemoryBestFitHeap::Result +AlternateMemoryBestFitHeap::AllocateAllocationValues( + absl::Span allocation_values) { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); // Data structure to contain the preferred offset for a given computation. @@ -971,8 +1136,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( absl::flat_hash_map preferred_offset_for_computation; - bool allocation_success = true; - for (auto& allocation_value : allocation_values) { + Result result = Result::kSuccess; + for (AllocationValue& allocation_value : allocation_values) { int64 definition_time = instruction_schedule.at(allocation_value.defining_instruction()); @@ -1086,20 +1251,19 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( request.start_time = std::min(definition_time, use_time); request.end_time = use_time; request.latest_prefetch_time = latest_prefetch_time; - request.size = colocated_intervals[0]->size; + request.size = allocation_value.size(); request.allow_no_copy_alternate_mem_allocation = allow_no_copy_alternate_mem_allocation; request.earliest_prefetch_time = earliest_prefetch_time; request.preferred_offset = preferred_offset; request.use = &use; request.allocation_value = &allocation_value; - if (!AllocateSegment(request)) { + result_mark(AllocateSegment(request), result); + if (result_requires_uncommit(result)) { // If the allocation finding failed (e.g., due to running out of // asynchronous copies), then fall back to allocating the buffer // entirely in the default memory. 
- UncommitPendingChunks(); - allocation_success = false; - break; + return result; } // If there are multiple uses, they can try using the memory allocation @@ -1125,24 +1289,8 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( aliased_allocation->chunk().offset; } } - if (!allocation_success) { - break; - } } - if (allocation_success) { - for (AllocationValue& allocation_value : allocation_values) { - for (auto& allocation : *allocation_value.allocation_sequence()) { - AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation, - &allocation_info_str_); - allocations_->push_back(std::move(allocation)); - } - } - } - - pending_chunks_.clear(); - pending_async_copies_.clear(); - pending_required_assignments_.clear(); - return allocation_success; + return result; } bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { @@ -1162,15 +1310,21 @@ void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) { ranges_.erase(copy_it); } -bool AsynchronousCopyOrdering::ViolatesOrdering(int64 start_time, - int64 end_time) const { +absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering( + int64 start_time, int64 end_time) const { // We allow identical start and end times. It is enough to check for just the // start time in case we find a match in ranges_ because the found value will // either be identical to {start_time, end_time} (and this doesn't violate) or // its start_time will be smaller and end_time will be larger (this violates). auto copy_it = ranges_.find( {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate}); - return copy_it != ranges_.end() && copy_it->start_time != start_time; + if (copy_it != ranges_.end() && copy_it->start_time != start_time) { + VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time + << ") and (" << copy_it->start_time << ", " << copy_it->end_time + << ")"; + return *copy_it; + } + return absl::nullopt; } /*static*/ MemorySpaceAssignment::Allocation* @@ -1228,9 +1382,7 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( allocations_->push_back(std::move(allocation)); } - pending_chunks_.clear(); - pending_async_copies_.clear(); - pending_required_assignments_.clear(); + ClearPendingChunks(); } absl::optional @@ -1407,7 +1559,40 @@ bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( return false; } -void AlternateMemoryBestFitHeap::UncommitPendingChunks() { +void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( + std::vector<RepackAllocationBlock*>& + allocations) { + for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { + allocations.push_back(&allocation_block); + } +} + +void AlternateMemoryBestFitHeap::ImportRepackedAllocations( + absl::Span<RepackAllocationBlock*> + repacked_allocations) { + interval_tree_ = {}; + for (RepackAllocationBlock* allocation_block : repacked_allocations) { + MemorySpaceAssignment::Allocation* allocation = allocation_block->opaque; + VLOG(3) << "Moved " << allocation->ToString() << ", size " + << allocation->chunk().size << " from " + << allocation_block->initial_offset << " to " + << allocation_block->offset; + allocation_block->opaque->mutable_chunk()->offset = + allocation_block->offset; + interval_tree_.Add(allocation_block->start_time, allocation_block->end_time, + {allocation_block->offset, allocation_block->size}); + allocation_block->initial_offset = allocation_block->offset; + allocation_block->offset = -1; + } +} + +void AlternateMemoryBestFitHeap::UncommitPendingChunks( + absl::Span<AllocationValue> allocation_values) { + // Clear the allocation sequences of the allocation values so that they can + // be rebuilt in case we retry allocation after uncommitting. + for (AllocationValue& allocation_value : allocation_values) { + allocation_value.allocation_sequence()->clear(); + } for (const auto& interval_and_chunk : pending_chunks_) { const BufferInterval& interval = interval_and_chunk.first; const Chunk& chunk = interval_and_chunk.second.chunk; @@ -1446,6 +1631,48 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() { } } } + ClearPendingChunks(); +} + +void AlternateMemoryBestFitHeap::FinalizeAllocations( + absl::Span<AllocationValue> allocation_values) { + absl::flat_hash_map<int64, std::vector<MemorySpaceAssignment::Allocation*>> + colocation_map; + for (AllocationValue& allocation_value : allocation_values) { + for (auto& allocation : *allocation_value.allocation_sequence()) { + AppendAllocationInfoDebugString(allocation_value, *allocation, + allocation_info_str_); + allocations_->push_back(std::move(allocation)); + MemorySpaceAssignment::Allocation* inserted_allocation = + allocations_->back().get(); + if (inserted_allocation->memory_space() == MemorySpace::kAlternate) { + colocation_map[inserted_allocation->chunk().offset].push_back( + inserted_allocation); + } + } + } + // Assume allocations that received the same offset need to be colocated. + // Export these to repack_allocation_blocks_ so that we can repack them to + // reduce fragmentation. + for (auto& colocation : colocation_map) { + std::vector<RepackAllocationBlock*> colocations; + for (MemorySpaceAssignment::Allocation* colocated_allocation : + colocation.second) { + repack_allocation_blocks_.push_back( + {colocated_allocation->start_time(), colocated_allocation->end_time(), + colocated_allocation->chunk().size, /*offset=*/-1, + colocated_allocation->chunk().offset, /*colocations=*/{}, + colocated_allocation}); + colocations.push_back(&repack_allocation_blocks_.back()); + } + for (RepackAllocationBlock* repack_block : colocations) { + repack_block->colocations = colocations; + } + } + ClearPendingChunks(); +} + +void AlternateMemoryBestFitHeap::ClearPendingChunks() { pending_chunks_.clear(); pending_async_copies_.clear(); pending_required_assignments_.clear(); @@ -1461,7 +1688,7 @@ void AlternateMemoryBestFitHeap::AddToPendingChunks( CommitChunk(buffer_interval, chunk_candidate); } -bool AlternateMemoryBestFitHeap::AllocateSegment( +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment( const AllocationRequest& request) { auto allocation_sequence = request.allocation_value->allocation_sequence(); // start_time == end_time is a special case where the value is consumed @@ -1472,7 +1699,7 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( GetLiveAllocationAt(*allocation_sequence, request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use->hlo_use); - return true; + return Result::kSuccess; } const HloPosition& defining_position = @@ -1536,12 +1763,15 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( } } + Result allocation_result = Result::kSuccess; // First try keeping the allocation entirely in the alternate memory.
if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && - request.allow_no_copy_alternate_mem_allocation && - AllocateInAlternateMemoryNoCopy(request)) { - return true; + request.allow_no_copy_alternate_mem_allocation) { + allocation_result = AllocateInAlternateMemoryNoCopy(request); + if (allocation_result == Result::kSuccess) { + return Result::kSuccess; + } } auto prev_allocation_it = allocation_sequence->rbegin(); @@ -1560,8 +1790,10 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( (*prev_allocation_it)->defining_position() == defining_position) { // If there was an allocation for this HloValue that was in the alternate // memory space, we also need to perform an eviction. - if (!Evict(request)) { - return false; + Result eviction_result = Evict(request); + if (eviction_result != Result::kSuccess) { + // A non-success eviction requires us to uncommit previous allocations. + return result_mark(Result::kFailRequiresUncommit, eviction_result); } prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { @@ -1582,31 +1814,28 @@ bool AlternateMemoryBestFitHeap::AllocateSegment( << "Not trying to prefetch because use requires buffer in default mem."; (*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); - return true; + return Result::kSuccess; } // Finally, try to prefetch the buffer into alternate memory. - if (Prefetch(request, **prev_allocation_in_default_mem_it)) { - return true; - } - if (!final_retry_ && prefetch_failed_due_to_async_copy_) { - // If prefetching failed due to asynchronous copy and we're not in our final - // try, return false (failure) so that we can retry this interval with - // larger limits. - return false; + Result prefetch_result = + Prefetch(request, **prev_allocation_in_default_mem_it); + if (prefetch_result == Result::kSuccess) { + return Result::kSuccess; } + result_mark(prefetch_result, allocation_result); // If the end assignment was required to be in alternate memory but that // wasn't possible, then this allocation is invalid. if (required_memory_space_at_end == MemorySpace::kAlternate) { - return false; + return result_mark(Result::kFailRequiresUncommit, allocation_result); } // If a copy wasn't inserted, then add this use to the latest allocation in // default memory. 
(*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); - return true; + return allocation_result; } void AlternateMemoryBestFitHeap::AddAsyncCopy( @@ -1667,12 +1896,14 @@ bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies( } } -bool AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering( - int64 start_time, int64 end_time) const { +absl::optional<AsynchronousCopy> +AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time, + int64 end_time) const { return async_copy_ordering_.ViolatesOrdering(start_time, end_time); } -bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( +AlternateMemoryBestFitHeap::Result +AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { MemorySpaceAssignment::Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; @@ -1691,7 +1922,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( } if (!can_eliminate_copy) { - return false; + return Result::kFailPrevAllocationNotInAlternateMem; } const HloPosition& defining_position = @@ -1699,7 +1930,7 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( defining_position.shape(), request.start_time + 1, request.end_time)) { - return false; + return Result::kFailLiveRangeTooLong; } BufferInterval alternate_mem_interval; @@ -1778,12 +2009,13 @@ bool AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy( } request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); - return true; + return Result::kSuccess; } - return false; + return Result::kFailOutOfMemory; } -bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict( + const AllocationRequest& request) { CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); MemorySpaceAssignment::Allocation* prev_allocation = request.allocation_value->allocation_sequence()->back().get(); @@ -1872,13 +2104,62 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) { << " and " << hlo_live_range_.flattened_instruction_sequence() .instructions()[eviction_end_time]; - return false; + return Result::kFailOutOfAsyncCopies; } } - return true; + return Result::kSuccess; } -bool AlternateMemoryBestFitHeap::Prefetch( +int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime( const AllocationRequest& request, int64 earliest_prefetch_time) const { + int64 prefetch_end_time = request.latest_prefetch_time; + + for (int retry_number = 0; + retry_number < options_.prefetch_copy_done_reorder_max_retries; + ++retry_number) { + int64 latest_prefetch_time = + options_.prefetch_interval_picker->LatestPrefetchStartTime( + request.use->hlo_use, earliest_prefetch_time, prefetch_end_time); + VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time + << ", earliest prefetch start time = " << earliest_prefetch_time + << ", prefetch end time = " << prefetch_end_time; + // Return if we couldn't find a suitable prefetch start time. + if (latest_prefetch_time < earliest_prefetch_time) { + break; + } + + // Return either if there is no other violating asynchronous copy (since we + // don't need to change the prefetch end time) or if the violating + // asynchronous copy ends after the prefetch end time.
+ auto violating_async_copy = + ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time); + if (!violating_async_copy || + violating_async_copy->end_time >= prefetch_end_time) { + break; + } + VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time + << ", " << violating_async_copy->end_time << ")"; + + int64 new_prefetch_end_time = + options_.prefetch_interval_picker->LatestPrefetchEndTime( + prefetch_end_time, violating_async_copy->end_time); + if (new_prefetch_end_time > earliest_prefetch_time) { + VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time; + prefetch_end_time = new_prefetch_end_time; + } else { + VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time + << " because earliest prefetch start time = " + << earliest_prefetch_time; + break; + } + } + + return prefetch_end_time; +} + +AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch( const AllocationRequest& request, const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) { // Try partially placing the buffer in the alternate space. The time that is @@ -1899,9 +2180,11 @@ bool AlternateMemoryBestFitHeap::Prefetch( earliest_prefetch_time = std::max(earliest_prefetch_time, *request.earliest_prefetch_time); } - options_.prefetch_interval_picker->Begin(request.use->hlo_use, - earliest_prefetch_time, - request.latest_prefetch_time); + int64 prefetch_end_time = + FindPrefetchEndTime(request, earliest_prefetch_time); + + options_.prefetch_interval_picker->Begin( + request.use->hlo_use, earliest_prefetch_time, prefetch_end_time); VLOG(3) << "Trying prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); @@ -1910,33 +2193,30 @@ bool AlternateMemoryBestFitHeap::Prefetch( BufferInterval alternate_mem_interval; alternate_mem_interval.buffer = request.allocation_value->value(); alternate_mem_interval.size = request.size; - // If any of the prefetch intervals couldn't be used due to number of - // outstanding async copy limit or async copy ordering, set - // prefetch_failed_due_to_async_copy_. - prefetch_failed_due_to_async_copy_ = false; // While uses might be allowed to have additional outstanding prefetches. int64 extra_async_copy_limit = request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile ? options_.while_use_extra_outstanding_prefetch_limit : 0; + Result result = Result::kSuccess; while (!options_.prefetch_interval_picker->Done()) { alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); - CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time); + CHECK_LT(alternate_mem_interval.start, prefetch_end_time); VLOG(4) << "Trying alternate memory allocation (" << alternate_mem_interval.start << ", " << request.end_time << ")"; // If this additional asynchronous copy would violate the limit, try a // different interval. 
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start, - request.latest_prefetch_time)) { + prefetch_end_time)) { VLOG(4) << "This would violate asynchronous copy ordering."; - prefetch_failed_due_to_async_copy_ = true; + result_mark(Result::kFailViolatesAsyncCopyOrdering, result); continue; } if (ViolatesMaximumOutstandingAsyncCopies( - alternate_mem_interval.start, request.latest_prefetch_time, + alternate_mem_interval.start, prefetch_end_time, /*is_prefetch=*/true, extra_async_copy_limit)) { VLOG(4) << "This would violate the outstanding async copy limit."; - prefetch_failed_due_to_async_copy_ = true; + result_mark(Result::kFailOutOfAsyncCopies, result); continue; } @@ -1955,16 +2235,22 @@ bool AlternateMemoryBestFitHeap::Prefetch( AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate, chunk_candidate->chunk, alternate_mem_interval.start, - request.end_time, request.latest_prefetch_time, + request.end_time, prefetch_end_time, request.allocation_value->allocation_sequence()); request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); - prefetch_failed_due_to_async_copy_ = false; - return true; + return Result::kSuccess; } + result_mark(Result::kFailOutOfMemory, result); + } + // If we didn't consider any prefetch intervals, then the live range was too + // short. + if (result == Result::kSuccess) { + return Result::kFailLiveRangeTooShort; + } else { + return result; } - return false; } absl::optional @@ -2113,6 +2399,9 @@ bool IsCrossProgramPrefetchCandidate( return value.instruction()->parent() == value.instruction()->GetModule()->entry_computation() && value.instruction()->opcode() == HloOpcode::kParameter && + (!value.shape().has_layout() || + value.shape().layout().memory_space() != + options.alternate_memory_space) && value.index().size() == 1 && value.shape().IsArray() && !value.uses().empty() && options.size_fn(value) <= options.max_size_in_bytes && @@ -2381,6 +2670,8 @@ Status MemorySpaceAssignment::CopyAllocation::Process( HloOpcode::kCopyStart, producing_instruction)); copy_done_ = computation->AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + VLOG(4) << "Created " << copy_start_->name() + << " for position: " << defining_position().ToString(); // Update the allocation position with the copy done instruction so that if // there are further copies from it, it can find the correct position. defining_position_ = HloPosition{copy_done_, {}}; @@ -2840,18 +3131,23 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { } } - if (last_use_instruction && - last_use_instruction->opcode() == HloOpcode::kConditional) { + std::function + split_conditional_buffer; + split_conditional_buffer = [&](const HloInstruction* use_instruction, + int64 start_time, int64 end_time, + absl::string_view indent_string) { // Special case when verifying conditional: we internally split the use // of alternate memory in conditionals, so fish them out from the // conditionals. 
- VLOG(3) << " Splitting conditional buffer: " << buffer.ToString() - << " value: " << value->ToShortString() << ": (" - << time_bound.start << ", " << time_bound.end - << ") off: " << chunk.offset << ", size: " << chunk.size; - int64 earliest_computation_start_time = time_bound.end; + VLOG(3) << indent_string + << "Splitting conditional buffer: " << buffer.ToString() + << " value: " << value->ToShortString() << ": (" << start_time + << ", " << end_time << ") off: " << chunk.offset + << ", size: " << chunk.size; + int64 earliest_computation_start_time = end_time; for (const HloComputation* called_computation : - last_use_instruction->called_computations()) { + use_instruction->called_computations()) { earliest_computation_start_time = std::min(earliest_computation_start_time, hlo_live_range->computation_span_times() @@ -2859,6 +3155,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { .start); int64 parameter_time = -1; int64 last_use_time = -1; + const HloInstruction* last_use_instruction = nullptr; for (const HloPosition& position : value->positions()) { if (position.instruction->opcode() == HloOpcode::kParameter && position.instruction->parent() == called_computation) { @@ -2868,27 +3165,45 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { } } for (const HloUse& use : value->uses()) { - if (use.instruction->parent() == called_computation) { - last_use_time = std::max( - last_use_time, - hlo_live_range->instruction_schedule().at(use.instruction)); + int64 use_time = + hlo_live_range->instruction_schedule().at(use.instruction); + if (use.instruction->parent() == called_computation && + use_time > last_use_time) { + last_use_time = use_time; + last_use_instruction = use.instruction; } } if (last_use_time != -1) { CHECK_NE(parameter_time, -1); - VLOG(3) << " computation: " << called_computation->name() << ": (" + VLOG(3) << indent_string + << " computation: " << called_computation->name() << ": (" << parameter_time << ", " << last_use_time << ")"; - TF_RETURN_IF_ERROR(add_allocation_and_verify( - parameter_time, last_use_time, chunk, value)); + CHECK(last_use_instruction); + if (last_use_instruction->opcode() == HloOpcode::kConditional) { + // The last use is another (nested) conditional. Call this + // function recursively. 
+ TF_RETURN_IF_ERROR(split_conditional_buffer( + last_use_instruction, parameter_time, last_use_time, + absl::StrCat(indent_string, " "))); + } else { + TF_RETURN_IF_ERROR(add_allocation_and_verify( + parameter_time, last_use_time, chunk, value)); + } } } - VLOG(3) << " from beginning until first computation: (" - << time_bound.start << ", " - << (earliest_computation_start_time - 1) << ")"; + VLOG(3) << indent_string << " from beginning until first computation: (" + << start_time << ", " << (earliest_computation_start_time - 1) + << ")"; TF_RETURN_IF_ERROR(add_allocation_and_verify( - time_bound.start, earliest_computation_start_time - 1, chunk, - value)); - } else { + start_time, earliest_computation_start_time - 1, chunk, value)); + return Status::OK(); + }; + + if (last_use_instruction && + last_use_instruction->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR(split_conditional_buffer( + last_use_instruction, time_bound.start, time_bound.end, " ")); + } else if (!value->uses().empty()) { VLOG(3) << " buffer: " << buffer.ToString() << " value: " << value->ToShortString() << ": (" << time_bound.start << ", " << time_bound.end diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 5c5329033fd..d366c06a599 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" namespace xla { @@ -84,6 +85,8 @@ class MemorySpaceAssignmentCostAnalysis { absl::flat_hash_map while_nest_multiplier; }; + virtual ~MemorySpaceAssignmentCostAnalysis() = default; + static StatusOr> Create( const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, @@ -126,18 +129,23 @@ class MemorySpaceAssignmentCostAnalysis { // BufferInterval is prefetched. float GetInstructionElapsedDueToMemorySlowdown(int64 bytes) const; + // Returns the estimated elapsed duration of the instruction in seconds. It + // assumes all operands and outputs of the instruction are in the default + // memory. + virtual float GetInstructionElapsed(const HloInstruction& instruction) const; + // Returns the estimated elapsed duration of the instruction in seconds. It // assumes all operands and outputs of the instruction are in the default // memory, except for the operand number that is in the alternate memory, if // provided, or output if output_in_alternate_mem is true. - float GetInstructionElapsed( + virtual float GetInstructionElapsedInAlternateMemory( const HloInstruction& instruction, - absl::optional operand_in_alternate_mem = absl::nullopt, - bool output_in_alternate_mem = false) const; + absl::optional operand_in_alternate_mem, + bool output_in_alternate_mem) const; // Returns the elapsed time it would take to asynchronously copy the shape // from default to alternate memory space (or vice versa). 
- float GetAsyncCopyElapsed(const Shape& shape) const; + virtual float GetAsyncCopyElapsed(const Shape& shape) const; int64 GetScheduleEndTime() const; @@ -147,7 +155,7 @@ class MemorySpaceAssignmentCostAnalysis { const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } - private: + protected: MemorySpaceAssignmentCostAnalysis( const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, @@ -164,6 +172,7 @@ class MemorySpaceAssignmentCostAnalysis { hlo_live_range_(std::move(hlo_live_range)), call_graph_(std::move(call_graph)) {} + private: const HloCostAnalysis& cost_analysis_; float async_copy_bandwidth_bytes_per_second_; float alternate_mem_bandwidth_bytes_per_second_; @@ -190,6 +199,17 @@ class PrefetchIntervalPicker { virtual int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const = 0; + // Returns the latest time that a prefetch can start. + virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, + int64 end_time) const = 0; + + // Returns the latest time that a prefetch can end that is less than or equal + // to proposed_prefetch_end_time. + virtual int64 LatestPrefetchEndTime(int64 original_prefetch_end_time, + int64 proposed_prefetch_end_time) const { + return proposed_prefetch_end_time; + } + // Begins the iterator for the first start time of the prefetch. virtual void Begin(const HloUse& use, int64 start_time, int64 end_time) = 0; @@ -248,6 +268,9 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; + int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, + int64 end_time) const override; + void Begin(const HloUse& use, int64 start_time, int64 end_time) override; int64 Next() override; @@ -267,16 +290,16 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { // Prefetch interval picker that uses cost analysis to overlap asynchronous // copies with independent computation. It uses min/max (asynchronous copy // duration) / (independent computation duration) ratios to guide whether the -// prefetch is within those bounds. It starts with the maximum allowed ratio -// (earliest prefetch) in Begin() and works its way for later and later prefetch -// with each Next() call until hitting the minimum ratio, in order not to hurt -// the critical path. +// prefetch is within those bounds. It starts with the preferred ratio in +// Begin() and works its way for alternately earlier and later prefetches until +// hitting min and max ratios. 
class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { public: CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, - float max_async_copy_to_overlap_ratio); + float max_async_copy_to_overlap_ratio, + float preferred_async_copy_to_overlap_ratio); bool CanAllocateInAlternateMemoryNoCopy(const Shape& shape, int64 start_time, int64 end_time) const override; @@ -284,6 +307,11 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; + int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, + int64 end_time) const override; + int64 LatestPrefetchEndTime(int64 original_prefetch_end_time, + int64 proposed_prefetch_end_time) const override; + void Begin(const HloUse& use, int64 start_time, int64 end_time) override; int64 Next() override; @@ -319,13 +347,17 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { const MemorySpaceAssignmentCostAnalysis& cost_analysis_; float min_async_copy_to_overlap_ratio_; float max_async_copy_to_overlap_ratio_; + float preferred_async_copy_to_overlap_ratio_; float max_overlap_multiplier_ = 1.0; float async_copy_elapsed_; float inst_elapsed_reduction_; int64 end_logical_time_; - int64 earliest_start_logical_time_; - int64 current_logical_prefetch_time_; + int64 earliest_prefetch_time_; + int64 latest_prefetch_time_; + bool using_increasing_prefetch_time_iterator_; + int64 increasing_prefetch_time_iterator_; + int64 decreasing_prefetch_time_iterator_; }; // MemorySpaceAssignment assigns memory spaces (default or alternate) to each @@ -348,6 +380,9 @@ class MemorySpaceAssignment { // space and a fast and small alternate memory space. enum class MemorySpace { kDefault, kAlternate }; + // Forward declaration for Allocation. + class Allocation; + // The different options to be passed to the Run() API. struct Options { // Backend-specific integer value that describes the alternate memory. @@ -383,11 +418,25 @@ class MemorySpaceAssignment { // max_outstanding_prefetches). int64 while_use_extra_outstanding_prefetch_limit = 0; + // Specifies the maximum number of times we are willing to move a copy + // done of a prefetch earlier due to an asynchronous copy ordering + // violation. + int64 prefetch_copy_done_reorder_max_retries = 1; + // Specifies the maximum number of retries that will be performed for each // value in case prefetching failed due to running out of asynchronous // copies or asynchronous copy ordering. int64 max_retries = 1; + // The maximum number of repacks that we are willing to perform in case we + // can't allocate a buffer due to running out of memory. If this value is + // greater than 0, repacker must be non-nullptr. + int64 max_repacks = 0; + + // The repacking algorithm to reduce fragmentation. Must be non-null if + // max_repacks is greater than 0. + MemorySpaceAssignmentRepacker* repacker = nullptr; + // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). 
bool allocate_across_sequential_calls = false; @@ -475,10 +524,12 @@ class MemorySpaceAssignment { const std::vector& uses() const { return uses_; } MemorySpace memory_space() const { return memory_space_; } Chunk chunk() const { return *chunk_; } + Chunk* mutable_chunk() { return &*chunk_; } void set_start_time(int64 start_time) { start_time_ = start_time; } int64 start_time() const { return start_time_; } int64 end_time() const { return end_time_; } + bool operator==(const Allocation& other) const; virtual std::string ToString() const; protected: @@ -501,6 +552,9 @@ class MemorySpaceAssignment { }; // This class represents an allocation as a result of an asynchronous copy. + // Note: CopyStart instructions are inserted after `start_time` or later, + // while CopyDone instructions are inserted before + // `copy_done_schedule_before_time` or earlier. class CopyAllocation : public Allocation { public: CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space, @@ -550,6 +604,7 @@ class MemorySpaceAssignment { copy_start_schedule_after_ = copy_start_schedule_after; } + bool operator==(const CopyAllocation& other) const; std::string ToString() const override; private: @@ -646,13 +701,15 @@ class MemorySpaceAssignment { std::vector aliases; }; - AllocationValue(const HloValue* value, const HloPosition& position) - : value_(value), defining_position_(position) {} + AllocationValue(const HloValue* value, const HloPosition& position, + int64 size) + : value_(value), defining_position_(position), size_(size) {} const HloPosition& defining_position() const { return defining_position_; } const HloInstruction* defining_instruction() const { return defining_position().instruction; } + int64 size() const { return size_; } const std::vector& uses() const { return uses_; } std::vector& uses() { return uses_; } const HloValue* value() const { return value_; } @@ -671,6 +728,7 @@ class MemorySpaceAssignment { private: const HloValue* value_; HloPosition defining_position_; + int64 size_; std::vector uses_; AllocationSequence allocation_sequence_; }; @@ -835,9 +893,9 @@ class AsynchronousCopyOrdering { // Removes an asynchronous copy. CHECKs that it is removed. void RemoveCopy(const AsynchronousCopy& copy); - // Returns true if the addition of an asynchronous copy in the the given time - // interval would violate the asynchronous copy ordering. E.g., consider the - // following scenario: + // If the addition of an asynchronous copy in the given time interval would + // violate the asynchronous copy ordering, returns the violating + // already-committed asynchronous copy. E.g., consider the following scenario: // CS CD // already committed async copy: +-----------+ // new async copy: +--------+ @@ -845,7 +903,8 @@ class AsynchronousCopyOrdering { // The new asynchronous copy would violate the ordering guarantee because the // copy start is after an already committed asynchronous copy while its copy // done is before the committed copy. - bool ViolatesOrdering(int64 start_time, int64 end_time) const; + absl::optional ViolatesOrdering(int64 start_time, + int64 end_time) const; private: // Stores asynchronous copies in a tree set respecting the pipelining order. @@ -884,6 +943,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { HeapSimulator::Result Finish() override; private: + using RepackAllocationBlock = MemorySpaceAssignmentRepacker< + MemorySpaceAssignment::Allocation*>::AllocationBlock; + // An allocation request for a use segment. 
A use segment is the time segment // between the definition and the first use, and the time segment between the // uses of a buffer. For example, the time between the definition and Use1, is @@ -916,6 +978,62 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { MemorySpaceAssignment::AllocationValue* allocation_value; }; + // Result of an allocation, prefetch, eviction etc. request. The result is + // either kSuccess or a bitwise OR of one or more failures. The values are + // unique powers of two. To check if a result contains a particular failure, + // use the result_is method. To add a new failure to a result, use the + // result_mark method. + enum class Result { + // Successful allocation. + kSuccess = 0, + // Allocation failed because we ran out of alternate memory. + kFailOutOfMemory = 1, + // A no-copy allocation couldn't be performed because the previous + // allocation wasn't in the alternate memory space. + kFailPrevAllocationNotInAlternateMem = 2, + // A no-copy allocation couldn't be performed because the live range was too + // long. + kFailLiveRangeTooLong = 4, + // A prefetching couldn't be performed because the live range was too short. + kFailLiveRangeTooShort = 8, + // Ran out of outstanding asynchronous copy limit either during prefetching + // or eviction. + kFailOutOfAsyncCopies = 16, + // A prefetching couldn't be performed because the asynchronous copy + // ordering was violated. + kFailViolatesAsyncCopyOrdering = 32, + // An allocation failure happened that requires uncommitting all the pending + // allocations. Usually this is due to a situation requiring an eviction but + // the eviction couldn't be performed. + kFailRequiresUncommit = 64 + }; + + // Return true if the result belongs to a failure. + static bool result_is(Result result, Result failure) { + return static_cast(result) & static_cast(failure); + } + + // Mark (bitwise OR) a failure to the result. + static Result result_mark(Result failure, Result& result) { + result = static_cast(static_cast(result) | + static_cast(failure)); + return result; + } + + // Return true if the result is a failure that requires us to uncommit pending + // chunks. + static bool result_requires_uncommit(Result result) { + return result_is(result, Result::kFailRequiresUncommit); + } + + // Return true if the result is a failure either due to running out of + // outstanding asynchronous copies or due to violating asynchronous copy + // ordering. + static bool result_failed_because_of_async_copy(Result result) { + return result_is(result, Result::kFailOutOfAsyncCopies) || + result_is(result, Result::kFailViolatesAsyncCopyOrdering); + } + // Given an allocation sequence, returns the live allocation at time with a // preference towards allocations in alternate memory. Returns nullptr if no // allocation is alive at that time. @@ -926,17 +1044,24 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { bool IsUseAllowedInAlternateMemory(const AllocationValue& value, const HloUse& use) const; - // Given an HloValue, creates AllocationValue objects and corresponding + // Given a BufferInterval, creates AllocationValue objects and corresponding // AllocationSequences and appends them into allocation_sequence_list_. - void CreateAllocationValues(const HloValue* value, - std::vector* allocation_values); + void CreateAllocationValues( + const BufferInterval& buffer_interval, + std::vector& allocation_values) const; - // Finds allocations for colocated intervals. 
Colocated intervals consist of - // one or more BufferIntervals, each with a different HloValue. All of the - // intervals within colocated intervals have a must-alias relationship with - // each other. Returns true if allocation succeeded. - bool AllocateColocatedIntervals( - const std::vector& colocated_intervals); + // Given colocated intervals, populates allocation_values with the + // corresponding AllocationValue objects. + void CreateAllocationValuesFromColocatedIntervals( + absl::Span colocated_intervals, + std::vector& allocation_values); + + // Finds allocations for allocation values generated from colocated intervals. + // All of the allocation values have a must-alias relationship with each + // other. Returns either kSuccess if all of the sites could be placed in the + // alternate memory or a bitwise OR of failure reasons why they couldn't + Result AllocateAllocationValues( + absl::Span allocation_values); // Go through all the uses in the AllocationValues and find the aliasing // positions. @@ -954,20 +1079,26 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // if there is enough space and if the prefetch interval picker allows. // // If an eviction (2) was requested and was unsuccessful, this method returns - // false. This means we could not find a suitable allocation, so all previous - // allocations for this buffer must be removed and allocated in the default - // memory. Otherwise, this method returns true. - bool AllocateSegment(const AllocationRequest& request); + // Result::kFailRequiresUncommit. This means we could not find a suitable + // allocation, so all previous allocations for this buffer must be removed and + // allocated in the default memory. Otherwise, this method may return + // Result::kSuccess if the buffer could be placed in alternate memory or some + // other Result with an OR of reasons why the buffer couldn't be placed in + // alternate memory. + Result AllocateSegment(const AllocationRequest& request); - // Try allocating in alternate memory without any copies. Returns true if - // successful. - bool AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); + // Try allocating in alternate memory without any copies. + Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); - // Try evicting to default memory space. Returns true if successful. - bool Evict(const AllocationRequest& request); + // Try evicting to default memory space. + Result Evict(const AllocationRequest& request); - // Try prefetching to alternate memory space. Returns true if successful. - bool Prefetch( + // Returns the time a copy done of a prefetch should be scheduled. + int64 FindPrefetchEndTime(const AllocationRequest& request, + int64 earliest_prefetch_time) const; + + // Try prefetching to alternate memory space. + Result Prefetch( const AllocationRequest& request, const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem); @@ -1030,8 +1161,20 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { int64 start_time, int64 end_time, bool is_prefetch, int64 extra_async_copy_limit = 0) const; - // Return true if the asynchronous copy would violate the pipelining order. - bool ViolatesAsyncCopyOrdering(int64 start_time, int64 end_time) const; + // If the asynchronous copy would violate the pipelining order, returns the + // violating asynchronous copy. 
+ absl::optional ViolatesAsyncCopyOrdering( + int64 start_time, int64 end_time) const; + + // Exports the allocations for repacking and puts them into the vector in the + // parameter. + void ExportAllocationsForRepacking( + std::vector& allocations); + + // Imports repacked allocations and updates the internal data structures + // consistent with the new packing. + void ImportRepackedAllocations( + absl::Span repacked_allocations); // Adds an asynchronous copy to the allocations. void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation, @@ -1047,17 +1190,24 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { const ChunkCandidate& chunk_candidate); // If we need to remove the allocations for this allocation sequence, this // removes pending chunks and asynchronous copies in the respective pending - // buffers from the interval trees. - void UncommitPendingChunks(); + // buffers from the interval trees. If an allocation request returns + // kFailRequiresUncommit, this method must be called. + void UncommitPendingChunks(absl::Span allocation_values); + + // Finalizes the allocations where they can no longer be uncommitted. + void FinalizeAllocations(absl::Span allocation_values); + + // Clears all pending chunks and asynchronous copies. + void ClearPendingChunks(); // Append buffer and allocation infos for debugging and dump it into a file, // if enabled. void AppendBufferInfoDebugString(const BufferInterval& interval, std::string* debug_str) const; void AppendAllocationInfoDebugString( - const BufferInterval& interval, + const AllocationValue& value, const MemorySpaceAssignment::Allocation& allocation, - std::string* debug_str) const; + std::string& debug_str) const; void DumpDebugStringsIfEnabled() const; // Returns the available heap size in the alternate memory. @@ -1074,6 +1224,11 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { BufferIntervalTree prefetch_interval_tree_; BufferIntervalTree eviction_interval_tree_; AsynchronousCopyOrdering async_copy_ordering_; + // A list of RepackAllocationBlock objects that mirrors allocation sequences, + // used for repacking. We use a list here because we need pointer stability + // for aliased allocations. + std::list repack_allocation_blocks_; + int64 num_repacks_ = 0; std::vector> pending_chunks_; std::vector pending_async_copies_; std::vector> @@ -1084,9 +1239,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { required_assignments_; // Number of bytes reserved in alternate memory space. int64 reserved_in_bytes_ = 0; - // Variables to control allocation retries. - bool final_retry_; - bool prefetch_failed_due_to_async_copy_; // Debug strings. std::string buffer_info_str_; std::string allocation_info_str_; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h new file mode 100644 index 00000000000..fcfdfc797fb --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// An interface to define allocation repacking algorithms. +template +class MemorySpaceAssignmentRepacker { + public: + MemorySpaceAssignmentRepacker() = default; + virtual ~MemorySpaceAssignmentRepacker() = default; + + // A contiguous block of allocation consisting of start and end (logical) + // times, size, and the initial offset. After repacking, if the repacking was + // successful and the allocations were modified, the offset field holds the + // new offset. To support aliased allocations, AllocationBlock also includes a + // vector of AllocationBlock pointers, called colocations. All AllocationBlock + // objects within the colocations must get the same offset. The opaque field + // is used by the MemorySpaceAssignment pass and should not be accessed by the + // repacking algorithm. + struct AllocationBlock { + int64 start_time; + int64 end_time; + int64 size; + int64 offset; + int64 initial_offset; + std::vector colocations; + O opaque; + }; + + // Repack the AllocationBlocks provided in the parameter. Returns true if + // allocations have been modified and false if not. Returns a non-ok status if + // there was an error. 
+ virtual StatusOr Repack(absl::Span allocations) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_REPACKING_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 10e11e55291..464cfb502be 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -60,7 +60,8 @@ class MemorySpaceAssignmentTest : public HloTestBase, CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, - /*max_async_copy_to_overlap_ratio=*/10.0)); + /*max_async_copy_to_overlap_ratio=*/10.0, + /*preferred_async_copy_to_overlap_ratio=*/1.5)); return AssignMemorySpace( module, /*max_outstanding_async_copies=*/-1, MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( @@ -70,19 +71,22 @@ class MemorySpaceAssignmentTest : public HloTestBase, std::unique_ptr AssignMemorySpace( HloModule* module, int64 max_outstanding_async_copies = -1, - int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2) { + int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2, + absl::optional options = absl::nullopt) { InstructionCountPrefetchIntervalPicker prefetch_interval_picker( min_prefetch_interval, max_prefetch_interval); return AssignMemorySpace(module, max_outstanding_async_copies, /*buffer_interval_compare=*/{}, - &prefetch_interval_picker); + &prefetch_interval_picker, options); } std::unique_ptr AssignMemorySpace( HloModule* module, int64 max_outstanding_async_copies, absl::optional buffer_interval_compare, - PrefetchIntervalPicker* prefetch_interval_picker) { + PrefetchIntervalPicker* prefetch_interval_picker, + absl::optional + memory_space_assignment_options = absl::nullopt) { auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; @@ -116,9 +120,15 @@ class MemorySpaceAssignmentTest : public HloTestBase, } MemorySpaceAssignment::Options options; + if (memory_space_assignment_options) { + options = *memory_space_assignment_options; + } else { + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + } + options.alternate_memory_space = kAlternateMemorySpace; - options.max_size_in_bytes = 128; - options.alignment_in_bytes = 8; options.buffer_interval_compare = buffer_interval_compare; options.prefetch_interval_picker = prefetch_interval_picker; options.size_fn = size_fn; @@ -126,7 +136,6 @@ class MemorySpaceAssignmentTest : public HloTestBase, options.max_outstanding_prefetches = max_outstanding_async_copies; options.max_outstanding_evictions = max_outstanding_async_copies; options.allocate_across_sequential_calls = GetParam(); - options.verify = true; auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); std::unique_ptr hlo_live_range = @@ -285,6 +294,92 @@ class MemorySpaceAssignmentTest : public HloTestBase, MemorySpaceAssignmentCostAnalysis::Cache cache_; }; +// For testing purposes, we define a cost analysis where we can control the +// elapsed times of each HLO and asynchronous copy. 
+class FakeMemorySpaceAssignmentCostAnalysis + : public MemorySpaceAssignmentCostAnalysis { + public: + static StatusOr> + Create(const HloCostAnalysis& cost_analysis, const HloModule& module) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis( + cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1, + /*alternate_mem_bandwidth_bytes_per_second=*/1, + std::move(alias_analysis), std::move(hlo_live_range), + std::move(call_graph))); + } + + float GetInstructionElapsed( + const HloInstruction& instruction) const override { + if (get_instruction_elapsed_override_) { + return get_instruction_elapsed_override_(instruction); + } + return 1.0; + } + + float GetInstructionElapsedInAlternateMemory( + const HloInstruction& instruction, + absl::optional operand_in_alternate_mem, + bool output_in_alternate_mem) const override { + if (get_instruction_elapsed_in_alternate_memory_override_) { + return get_instruction_elapsed_in_alternate_memory_override_( + instruction, operand_in_alternate_mem, output_in_alternate_mem); + } + if (operand_in_alternate_mem) { + return 0.5; + } else { + return 1.0; + } + } + + float GetAsyncCopyElapsed(const Shape& shape) const override { + if (get_async_copy_elapsed_override_) { + return get_async_copy_elapsed_override_(shape); + } + return 3.0; + } + + // The following methods can be used to override what the above API calls + // return. + void SetOverrideForGetInstructionElapsed( + std::function function) { + get_instruction_elapsed_override_ = function; + } + void SetOverrideForGetInstructionElapsedInAlternateMemory( + std::function, bool)> + function) { + get_instruction_elapsed_in_alternate_memory_override_ = function; + } + void SetOverrideForGetAsyncCopyElapsed( + std::function function) { + get_async_copy_elapsed_override_ = function; + } + + protected: + FakeMemorySpaceAssignmentCostAnalysis( + const HloCostAnalysis& cost_analysis, + float async_copy_bandwidth_bytes_per_second, + float alternate_mem_bandwidth_bytes_per_second, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : MemorySpaceAssignmentCostAnalysis( + cost_analysis, async_copy_bandwidth_bytes_per_second, + alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph)) {} + + private: + std::function + get_instruction_elapsed_override_ = nullptr; + std::function, bool)> + get_instruction_elapsed_in_alternate_memory_override_ = nullptr; + std::function get_async_copy_elapsed_override_ = nullptr; +}; + TEST_P(MemorySpaceAssignmentTest, ParameterOnly) { // A module consisting of a single parameter. Inputs/outputs are currently // excluded from memory space assignment. @@ -1718,6 +1813,59 @@ TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) { } } +TEST_P(MemorySpaceAssignmentTest, WhileSharedBufferVerificationBug) { + // Tests a spurious verification failure when a while has the same value + // passed in twice (copy0) and that value is evicted within the while loop. 
+ absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + while_cond { + p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + ROOT gte = pred[] get-tuple-element(p0), index=3 + } + + while_body { + p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0) + gte0 = f32[3]{0} get-tuple-element(p0), index=0 + gte1 = f32[3]{0} get-tuple-element(p0), index=1 + gte2 = f32[3]{0} get-tuple-element(p0), index=2 + gte3 = pred[] get-tuple-element(p0), index=3 + add = f32[3]{0} add(gte0, gte0) + negate0 = f32[3]{0} negate(add) + negate1 = f32[3]{0} negate(negate0) + negate2 = f32[3]{0} negate(negate1) + negate3 = f32[3]{0} negate(negate2) + negate4 = f32[3]{0} negate(negate3) + negate5 = f32[3]{0} negate(negate4) + negate6 = f32[3]{0} negate(negate5) + negate7 = f32[3]{0} negate(negate6) + negate8 = f32[3]{0} negate(negate7) + negate9 = f32[3]{0} negate(negate8) + negate10 = f32[3]{0} negate(negate9) + negate11 = f32[3]{0} negate(negate10) + negate12 = f32[3]{0} negate(negate11) + negate13 = f32[3]{0} negate(negate12) + negate14 = f32[3]{0} negate(negate13) + negate15 = f32[3]{0} negate(negate14) + negate16 = f32[3]{0} negate(negate15) + ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte0, negate16, gte3) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy0 = f32[3]{0} copy(p0) + copy1 = f32[3]{0} copy(p0) + tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1) + while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body + ROOT gte = f32[3]{0} get-tuple-element(while), index=2 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) { // Having control_predecessors on an HLO was preventing us from DCEing an op // that doesn't have any users (tuple.1). The scheduler assumes the graph is @@ -2066,6 +2214,58 @@ TEST_P(MemorySpaceAssignmentTest, NestedConditional) { } } +TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) { + // Tests a spurious verification failure when there are nested conditionals + // and the innermost conditional computation reuses the buffer. Here, both the + // parameter of true_computation2 and neg2 will get the same buffer. Make sure + // that verification doesn't claim a failure in this case. 
+ absl::string_view hlo_string = R"( + HloModule CondAllocation, is_scheduled=true + + true_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + neg1 = f32[3]{0} negate(gte) + neg2 = f32[3]{0} negate(neg1) + ROOT neg3 = f32[3]{0} negate(neg2) + } + + false_computation2 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg4 = f32[3]{0} negate(gte) + } + + true_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + slice = f32[1]{0} slice(gte), slice={[0:1]} + bitcast = f32[] bitcast(slice) + constant = f32[] constant(0.0) + compare = pred[] compare(bitcast, constant), direction=GT + tuple = (f32[3]{0}) tuple(gte) + ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2 + } + + false_computation1 { + p0 = (f32[3]{0}) parameter(0) + gte = f32[3]{0} get-tuple-element(p0), index=0 + ROOT neg5 = f32[3]{0} negate(gte) + } + + ENTRY entry { + p0 = f32[3]{0} parameter(0) + p1 = pred[] parameter(1) + copy = f32[3]{0} copy(p0) + tuple = (f32[3]{0}) tuple(copy) + ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AssignMemorySpace(module.get()); +} + TEST_P(MemorySpaceAssignmentTest, RequestIdentifierShouldNotBeAllocatedInAlternateMem) { // Ensure that request identifier returned by Send/Recv HLOs are not allocated @@ -3749,6 +3949,286 @@ TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) { buffer_interval_compare, &prefetch_interval_picker); } +TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) { + // This tests the case where an earlier placed smaller buffer may block a + // larger buffer due to asynchronous copy ordering. The smaller buffer (the + // operand of sin) will be placed first. The cos, whose operand is 3 times + // larger than sin's, needs longer time for the asynchronous copy.
The cos is + // placed right after sin, leading to a copy ordering violation: + // + // param1------------------>CS----->CD->sin + // param0------------->CS------------------->CD->cos + // + // To fix this, we need to move copy done for cos earlier and ensure both of + // these buffers get alternate memory allocations: + // + // param1------------------>CS----->CD->sin + // param0-->CS------------------->CD------------>cos + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + ENTRY Entry { + param0 = f32[8,3] parameter(0) + param1 = f32[2,4] parameter(1) + a = f32[2,4] negate(param1) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + sin = f32[2,4] sine(param1) + o = f32[2,4] negate(n) + cos = f32[8,3] cosine(param0) + ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o) + } + )"; + + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& a, + const MemorySpaceAssignment::BufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kSin: + return 0; + case HloOpcode::kCos: + return 1; + case HloOpcode::kTanh: + return 2; + default: + return 3; + } + }; + + auto get_user_priority = [&](const HloValue& value) { + int priority = INT_MAX; + for (const auto& use : value.uses()) { + priority = std::min(priority, + get_opcode_priority(use.instruction->opcode())); + } + return priority; + }; + + return get_user_priority(*a.buffer) < get_user_priority(*b.buffer); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) { + // This should return 2 for f32[2,4] and 6 for f32[8,3]. + return ShapeSize(shape) / 16; + }); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/4.0, + /*preferred_async_copy_to_overlap_ratio=*/1.5); + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + buffer_interval_compare, &interval_picker); + + // Check that both cos and sin could get their operands prefetched. + const HloInstruction* cos = + module->entry_computation()->GetInstructionWithName("cos"); + const HloInstruction* sin = + module->entry_computation()->GetInstructionWithName("sin"); + EXPECT_THAT(sin->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::Parameter(1))); + EXPECT_THAT(cos->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + op::Parameter(0))); + + // Sanity check that the cos' operand copy-done is scheduled earlier than + // sin's operand. 
+ auto find_schedule_index = [&](const HloInstruction* instruction) { + const auto& instructions = + module->schedule().sequence(module->entry_computation()).instructions(); + for (int i = 0; i < instructions.size(); ++i) { + if (instruction == instructions[i]) { + return i; + } + } + CHECK(false); + return -1; + }; + EXPECT_GT(find_schedule_index(sin->operand(0)), + find_schedule_index(cos->operand(0))); +} + +// A mock MemorySpaceAssignmentRepacker class that accepts a map of +// (start_time,offset) -> new_offset values. Using this map, the repacker +// repacks the allocations to the new_offset. +class FakeMemorySpaceAssignmentRepacker + : public MemorySpaceAssignmentRepacker<MemorySpaceAssignment::Allocation*> { + public: + FakeMemorySpaceAssignmentRepacker( + absl::flat_hash_map<std::pair<int64, int64>, int64>& repack_map) + : repack_map_(repack_map) {} + + StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override { + bool modified = false; + for (AllocationBlock* block : allocations) { + VLOG(1) << "Alloc time: [" << block->start_time << ", " << block->end_time + << "] size: " << block->size + << " init offset: " << block->initial_offset; + auto it = repack_map_.find({block->start_time, block->initial_offset}); + if (it != repack_map_.end()) { + modified = true; + block->offset = it->second; + } else { + block->offset = block->initial_offset; + } + for (AllocationBlock* colocation : block->colocations) { + VLOG(1) << " [" << colocation->start_time << ", " + << colocation->end_time << "]"; + if (it != repack_map_.end()) { + colocation->offset = it->second; + } else { + colocation->offset = colocation->initial_offset; + } + } + } + + return modified; + } + + private: + // A map from (start_time, offset) to new_offset. + absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map_; +}; + +TEST_P(MemorySpaceAssignmentTest, Repack) { + // We initially perform the following allocations at these offsets. + // + // Max memory + // ------------------------------------------- + // + // + // + // + // +------------+ + // | b | + // +------------+ + // +-------+ +------------+ + // | a | | n | + // +-------+ +------------+ + // ------------------------------------------- + // Min memory time -> + // + // Next up, we try to allocate the prefetch for m.
However due to + // fragmentation, this won't be possible: + // + // Max memory + // ------------------------------------------- + // + // + // + // +---------+ + // +------------+ | + // | b | | | + // +------------+ | + // +-------+ | | +------------+ + // | a | | d | | n | + // +-------+ +---------+ +------------+ + // ------------------------------------------- + // Min memory time -> + // + // We then call repack to repack the existing allocations which allows us to + // allocate the prefetch for m: + // + // Max memory + // ------------------------------------------- + // +---------+ + // | | + // | | + // | | + // +-------+ | | + // | a | | d | + // +-------+ +---------+ + // +------------+ +------------+ + // | b | | n | + // +------------+ +------------+ + // ------------------------------------------- + // Min memory time -> + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[8,3] parameter(0) + param1 = f32[2,4] parameter(1) + a = f32[2,4] sine(param1) + b = f32[2,4] cosine(param1) + c = f32[8,3] negate(param0) + j = f32[2,4] negate(a) + d = f32[8,3] tanh(param0) + k = f32[2,4] negate(j) + l = f32[2,4] add(b, k) + m = f32[8,3] negate(d) + n = f32[2,4] sine(l) + o = f32[8,3] negate(m) + p = f32[2,4] negate(n) + q = f32[8,3] negate(m) + ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o) + } + )"; + + MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare = + [](const MemorySpaceAssignment::BufferInterval& a, + const MemorySpaceAssignment::BufferInterval& b) { + auto get_opcode_priority = [](const HloOpcode& opcode) { + switch (opcode) { + case HloOpcode::kSin: + return 0; + case HloOpcode::kCos: + return 1; + case HloOpcode::kTanh: + return 2; + default: + return 3; + } + }; + + return get_opcode_priority(a.buffer->defining_instruction()->opcode()) < + get_opcode_priority(b.buffer->defining_instruction()->opcode()); + }; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10); + absl::flat_hash_map, int64> repack_map; + // Move "a" from offset 0 to 32. + repack_map[{2, 0}] = 32; + // Move "b" from offset 32 to 0. + repack_map[{3, 32}] = 0; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map); + MemorySpaceAssignment::Options options; + options.max_size_in_bytes = 128; + options.alignment_in_bytes = 8; + options.verify = true; + options.max_repacks = 1; + options.repacker = &repacker; + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + buffer_interval_compare, &prefetch_interval_picker, + options); + + // If repacking succeeds, we should find the buffer for d in alternate memory. + const HloInstruction* d = + module->entry_computation()->GetInstructionWithName("d"); + EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace); +} + TEST_P(MemorySpaceAssignmentTest, Determinism) { // Run memory space assignment a few times to make sure every time it compiles // to the same thing. 
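The FakeMemorySpaceAssignmentRepacker above is the map-driven test double; for reference, the smallest conforming implementation of the repacking interface is sketched below (a hypothetical NoOpRepacker, not part of this change). It leaves every block at its initial offset and reports that nothing was modified; the template argument mirrors the MemorySpaceAssignment::Allocation* instantiation used by AlternateMemoryBestFitHeap's RepackAllocationBlock alias.

    class NoOpRepacker
        : public MemorySpaceAssignmentRepacker<MemorySpaceAssignment::Allocation*> {
     public:
      StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override {
        for (AllocationBlock* block : allocations) {
          // Keep the original packing. Blocks in the same colocation group must
          // end up at the same offset, which holds trivially here because every
          // block simply keeps its initial offset.
          block->offset = block->initial_offset;
        }
        return false;  // No allocation was moved.
      }
    };

A real repacking algorithm would instead search for a new, less fragmented set of offsets and return true when it changes any of them; MemorySpaceAssignment then retries the allocation that ran out of memory, as exercised by the Repack test above via options.max_repacks = 1 and options.repacker = &repacker.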
@@ -4045,5 +4525,278 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) { EXPECT_EQ(cross_program_prefetches.size(), 0); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) { + HloComputation::Builder builder(TestName()); + + constexpr int kBatch = 8; + constexpr int kFeature = 8; + constexpr int kOutput = 2; + + auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature}); + auto rhs_shape = ShapeUtil::MakeShapeWithLayout( + F32, {kFeature, kOutput}, + /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0, + kAlternateMemorySpace); + auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput}); + auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "p0")); + + auto lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs_shape, param, 0)); + auto rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs_shape, param, 1)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); + + auto module = CreateNewVerifiedModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, lhs, rhs, dot}); + TF_CHECK_OK(module->set_schedule(schedule)); + + AssignMemorySpace(module.get()); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 0); +} + +using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + ENTRY Entry { + param0 = f32[2,4] parameter(0) + a = f32[2,4] negate(param0) + b = f32[2,4] negate(a) + c = f32[2,4] negate(b) + d = f32[2,4] negate(c) + e = f32[2,4] negate(d) + f = f32[2,4] negate(e) + g = f32[2,4] negate(f) + h = f32[2,4] negate(g) + i = f32[2,4] negate(h) + j = f32[2,4] negate(i) + k = f32[2,4] negate(j) + l = f32[2,4] negate(k) + m = f32[2,4] negate(l) + n = f32[2,4] negate(m) + o = f32[2,4] negate(n) + p = f32[2,4] negate(o) + q = f32[2,4] negate(p) + r = f32[2,4] negate(q) + s = f32[2,4] negate(r) + t = f32[2,4] negate(s) + u = f32[2,4] negate(t) + ROOT v = f32[2,4] add(u, param0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/4.0, + /*preferred_async_copy_to_overlap_ratio=*/2.0); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22); + + // Expect that the first interval is (15, 22), which has an elapsed time of 6.0, + // twice the async copy elapsed time (3.0).
Then we expect that intervals will be + // visited in alternating increasing and decreasing orders until hitting the + // min and max async copy overlap ratios, which are the intervals (18, 22) + // and (9, 22) respectively. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 15); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 16); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 14); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 13); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 12); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 11); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 10); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 9); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); + + // Expect that if the time between start_time and end_time is too short, there + // won't be any available intervals. + interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition { + param1 = (f32[2,4]) parameter(0) // 19 + ROOT cond = pred[] constant(true) // 20 + } + + while_body { + param2 = (f32[2,4]) parameter(0) // 21 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22 + add = f32[2,4] add(gte2, gte2) // 23 + ROOT tuple2 = (f32[2,4]) tuple(add) // 24 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + d = f32[2,4] negate(c) // 4 + e = f32[2,4] negate(d) // 5 + f = f32[2,4] negate(e) // 6 + g = f32[2,4] negate(f) // 7 + h = f32[2,4] negate(g) // 8 + i = f32[2,4] negate(h) // 9 + j = f32[2,4] negate(i) // 10 + k = f32[2,4] negate(j) // 11 + l = f32[2,4] negate(k) // 12 + m = f32[2,4] negate(l) // 13 + n = f32[2,4] negate(m) // 14 + o = f32[2,4] negate(n) // 15 + p = f32[2,4] negate(o) // 16 + q = f32[2,4] negate(p) // 17 + tuple = (f32[2,4]) tuple(q) // 18 + while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 26 + r = f32[2,4] negate(gte1) // 27 + s = f32[2,4] negate(r) // 28 + t = f32[2,4] negate(s) // 29 + u = f32[2,4] negate(t) // 30 + ROOT v = f32[2,4] add(u, param0) // 31 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/12.0, + /*preferred_async_copy_to_overlap_ratio=*/2.0); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, 
/*operand_index=*/{}}; + interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31); + + // Because there are while loop computations between [19, 24], we ensure that + // the interval picker avoids this interval. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 25); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 26); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 18); + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 27); // Min async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_EQ(interval_picker.Next(), 17); // Max async overlap ratio reached. + LOG(INFO) << interval_picker.ToDebugString(); + EXPECT_TRUE(interval_picker.Done()); +} + +TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { + // This test is to check against a bug where we didn't assign + // while_nest_level_ for while instructions, and defaulting to 0. This could + // cause the prefetch interval logic to think a nested while instruction is + // the same level as the outermost computation. + absl::string_view hlo_string = R"( + HloModule bug, is_scheduled=true + + while_condition.2 { + param1 = (f32[2,4]) parameter(0) // 11 + ROOT cond = pred[] constant(true) // 12 + } + + while_body.2 { + param2 = (f32[2,4]) parameter(0) // 13 + gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14 + add = f32[2,4] add(gte2, gte2) // 15 + ROOT tuple2 = (f32[2,4]) tuple(add) // 16 + } + + while_condition.1 { + param3 = (f32[2,4]) parameter(0) // 5 + ROOT cond = pred[] constant(true) // 6 + } + + while_body.1 { + param4 = (f32[2,4]) parameter(0) // 7 + gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8 + add1 = f32[2,4] add(gte1, gte1) // 9 + tuple1 = (f32[2,4]) tuple(add1) // 10 + while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17 + gte2 = f32[2,4] get-tuple-element(while), index=0 // 18 + add2 = f32[2,4] add(gte2, gte2) // 19 + ROOT tuple2 = (f32[2,4]) tuple(add2) // 20 + } + + ENTRY Entry { + param0 = f32[2,4] parameter(0) // 0 + a = f32[2,4] negate(param0) // 1 + b = f32[2,4] negate(a) // 2 + c = f32[2,4] negate(b) // 3 + tuple = (f32[2,4]) tuple(c) // 4 + while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21 + gte1 = f32[2,4] get-tuple-element(while), index=0 // 22 + ROOT root = f32[2,4] add(gte1, param0) // 23 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloCostAnalysis hlo_cost_analysis(ShapeSize); + TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis, + FakeMemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, *module)); + CostAnalysisPrefetchIntervalPicker interval_picker( + *cost_analysis, + /*min_async_copy_to_overlap_ratio=*/1.0, + /*max_async_copy_to_overlap_ratio=*/12.0, + /*preferred_async_copy_to_overlap_ratio=*/2.0); + + HloInstruction* root = module->entry_computation()->root_instruction(); + const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + + // We expect the root's latest prefetch start time to be before the while loop + // (logical time 4). 
+ EXPECT_EQ(interval_picker.LatestPrefetchStartTime(use, /*start_time=*/0, + /*end_time=*/23), + 4); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 113c9764b40..31cf36dee85 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -82,6 +82,7 @@ cc_library( ":kernel_lowering", ":lhlo_dialect_emitter", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Core", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", @@ -154,13 +155,33 @@ cc_library( ], ) +cc_library( + name = "passes", + srcs = ["passes.cc"], + hdrs = ["passes.h"], + deps = [ + "//tensorflow/compiler/mlir/hlo:lhlo", + "@com_google_absl//absl/memory", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFTransforms", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "kernel_lowering", srcs = ["kernel_lowering.cc"], hdrs = ["kernel_lowering.h"], deps = [ + ":passes", "//tensorflow/compiler/mlir/hlo", - "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", + "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration", "//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo", "//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/hlo:legalize_to_linalg", @@ -172,9 +193,7 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", - "@llvm-project//mlir:Affine", "@llvm-project//mlir:AffineToStandardTransforms", "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:GPUDialect", @@ -183,7 +202,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMTransforms", - "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:NVVMDialect", @@ -192,7 +210,6 @@ cc_library( "@llvm-project//mlir:SCFToGPUPass", "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 648c44d9ac1..ae99cc9ba63 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -22,419 +22,26 @@ limitations under the License. 
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" // from @llvm-project #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project -#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project #include "mlir/Dialect/GPU/Passes.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/NVVMDialect.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/Dialect/SCF/Passes.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project #include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/Region.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/BufferPlacement.h" // from @llvm-project -#include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "mlir/Transforms/LoopUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" #include "tensorflow/compiler/xla/util.h" namespace xla { namespace mlir_gpu { -namespace { - -using ::mlir::lmhlo::FusionOp; - -// Replaces a FusionOp by the operations contained in its region. -struct FusionOpRemover - : public mlir::PassWrapper { - void runOnFunction() override { - getFunction().walk([&](FusionOp op) { - mlir::OpBuilder builder(op); - // FusionOp has a single region with a single block, so we can just walk - // over it and clone operations to the outside. - mlir::BlockAndValueMapping mapping; - for (auto& nested_op : op.region().front().without_terminator()) { - auto clone = builder.clone(nested_op, mapping); - for (auto pair : - llvm::zip(nested_op.getResults(), clone->getResults())) { - mapping.map(std::get<0>(pair), std::get<1>(pair)); - } - } - op.erase(); - }); - } -}; - -// Simple pass that replaces a load that immediately follows a store to the -// same address with the stored value. This needs generalization. -struct StoreForwardingPass - : mlir::PassWrapper { - mlir::StoreOp findStore(mlir::Operation* op, - std::function matches) { - // Search from op upwards in the current block. 
- mlir::Block* block = op->getBlock(); - auto startFromIt = - std::find_if(block->rbegin(), block->rend(), - [op](mlir::Operation& other) { return &other == op; }); - for (auto storeOpIt = startFromIt; storeOpIt != block->rend(); - ++storeOpIt) { - auto storeOp = llvm::dyn_cast(&*(storeOpIt)); - if (!storeOp || !matches(storeOp)) { - continue; - } - - return storeOp; - } - // No store operation found. Continue search outside of the parallel - // loop if block is in a parallel loop. - if (auto parallelOp = - llvm::dyn_cast(block->getParentOp())) { - return findStore(parallelOp.getOperation(), matches); - } - return {}; - } - - // Recursively search defining ops for AllocOp. Return either AllocOp if it is - // found or nullptr. - mlir::Operation* SearchAllocOp(mlir::Value memref) { - mlir::Operation* defOp = memref.getDefiningOp(); - while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { - defOp = subviewOp.source().getDefiningOp(); - } - if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { - return allocOp.getOperation(); - } - return nullptr; - } - - // Retrieves AllocOp from the cache or actually looks for it. - mlir::Operation* GetAllocOp( - mlir::Value memref, - llvm::DenseMap* memrefToAllocOp) { - auto allocOpIt = memrefToAllocOp->find(memref); - if (allocOpIt != memrefToAllocOp->end()) { - return allocOpIt->second; - } - auto allocOp = SearchAllocOp(memref); - memrefToAllocOp->insert({memref, allocOp}); - return allocOp; - } - - void runOnFunction() override { - llvm::DenseMap memrefToAllocOp; - - getFunction().walk([&](mlir::LoadOp loadOp) { - auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) { - mlir::Operation* storeOpAlloc = - GetAllocOp(storeOp.memref(), &memrefToAllocOp); - mlir::Operation* loadOpAlloc = - GetAllocOp(loadOp.memref(), &memrefToAllocOp); - return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc); - }); - if (!storeOp) { - return; - } - auto storeIndices = storeOp.getIndices(); - auto loadIndices = loadOp.getIndices(); - if (!std::equal(storeIndices.begin(), storeIndices.end(), - loadIndices.begin(), loadIndices.end())) { - return; - } - loadOp.replaceAllUsesWith(storeOp.getValueToStore()); - loadOp.erase(); - }); - } -}; - -// Simple pass that removes temporary buffers that are only written to but -// never read from or that are read but the read value is not used. -// Needs an analysis that proves that loads and stores are side-effect free -// (in bounds, no aliasing, etc.). -struct DeadTempBufferRemoval - : mlir::PassWrapper { - bool operationConsideredDead(mlir::Operation* op) { - for (auto result : op->getResults()) { - if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { - // Store and Dealloc is OK. - if (llvm::isa(op)) { - return true; - } - // Load without uses is also ok. - if (auto loadOp = llvm::dyn_cast(op)) { - return loadOp.use_empty(); - } - // Subview is ok if it is dead itself. - if (llvm::isa(op)) { - return operationConsideredDead(op); - } - return false; - })) { - return false; - } - } - return true; - } - - void recursiveErase(mlir::Operation* op, - llvm::SmallVectorImpl* erase_list) { - for (auto result : op->getResults()) { - for (auto user : llvm::make_early_inc_range(result.getUsers())) { - recursiveErase(user, erase_list); - } - } - erase_list->push_back(op); - } - - void runOnFunction() override { - llvm::SmallVector dead_ops; - getFunction().walk([&](mlir::AllocOp allocOp) { - if (!operationConsideredDead(allocOp)) { - return; - } - - // TODO(herhut): There should be a generic helper for this. 
- recursiveErase(allocOp, &dead_ops); - }); - for (auto op : dead_ops) { - op->erase(); - } - } -}; - -// TODO(herhut): Move this to MLIR core. -struct MoveScalarComputationsIntoGpuLaunch - : mlir::PassWrapper { - static bool isInliningBeneficiary(mlir::Operation* op) { - return llvm::isa(op); - } - - static bool extractBeneficiaryOps( - mlir::Operation* op, llvm::SmallVectorImpl* ops, - llvm::SetVector args) { - if (!isInliningBeneficiary(op)) { - return false; - } - - ops->push_back(op); - for (auto operand : op->getOperands()) { - // It is an existing arg, keep going. - if (args.count(operand)) { - continue; - } - mlir::Operation* definingOp = operand.getDefiningOp(); - if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { - return false; - } - } - return true; - } - - static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { - llvm::SetVector used_above; - mlir::getUsedValuesDefinedAbove(launch.body(), used_above); - mlir::BlockAndValueMapping inlined_map; - for (mlir::Value v : used_above) { - llvm::SmallVector ops_to_move; - mlir::Operation* definingOp = v.getDefiningOp(); - if (definingOp && - extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { - mlir::OpBuilder b(launch.body()); - for (mlir::Operation* op : llvm::reverse(ops_to_move)) { - auto result = b.clone(*op, inlined_map); - for (auto pair : llvm::zip(op->getResults(), result->getResults())) { - mlir::replaceAllUsesInRegionWith(std::get<0>(pair), - std::get<1>(pair), launch.body()); - } - inlined_map.map(op->getResults(), result->getResults()); - } - } - } - } - - void runOnFunction() override { - mlir::FuncOp fun = getFunction(); - fun.walk( - [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); - } -}; - -// Sort the operands to the kernel for a deterministic order. First operands -// that are defined by function arguments, followed by operands that are -// returned from the function. This only works for simple functions without -// control flow and can be used in cases where the kernel is extracted and used -// independently of the host-side code. -struct RewriteKernelSignature - : mlir::PassWrapper { - void runOnFunction() override { - mlir::FuncOp func = getFunction(); - mlir::ModuleOp module = func.getParentOfType(); - getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) { - mlir::gpu::GPUFuncOp kernel = - module.lookupSymbol(launchOp.kernel()); - - if (kernel.getNumFuncArguments() != - func.getNumArguments() + func.getNumResults()) { - kernel.emitError() - << "number of kernel arguments does not match number" - << "of arguments and results of surrounding function"; - signalPassFailure(); - return; - } - if (!llvm::hasSingleElement(func)) { - func.emitError() << "surrounding function has more than one block"; - signalPassFailure(); - return; - } - - // Compute a map from function arguments to kernel function operands. - mlir::BlockAndValueMapping func_to_kernel; - for (mlir::BlockArgument arg : func.getArguments()) { - for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { - if (launchOp.getKernelOperand(i) == arg) { - func_to_kernel.map(arg, kernel.getArgument(i)); - break; - } - } - } - // Also add function results that are computed by the launch. 
- mlir::Operation* returnOp = func.getBody().back().getTerminator(); - for (mlir::Value result : returnOp->getOperands()) { - for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { - if (launchOp.getKernelOperand(i) == result) { - func_to_kernel.map(result, kernel.getArgument(i)); - break; - } - } - } - - // Create a new kernel function with modified signature. It will have the - // parameters and result types of the original funcion as its parameter - // type and otherwise will be void. - auto gpu_module = kernel.getParentOfType(); - mlir::OpBuilder kernel_builder(gpu_module.body()); - auto operand_types = llvm::to_vector<4>(llvm::concat( - func.getType().getInputs(), func.getType().getResults())); - auto new_kernel = kernel_builder.create( - kernel.getLoc(), kernel.getName(), - kernel_builder.getFunctionType(operand_types, {})); - new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), - kernel_builder.getUnitAttr()); - - // Create a map from old kernel argument to new one. - mlir::BlockAndValueMapping old_kernel_to_new; - for (int i = 0, e = func.getNumArguments(); i < e; ++i) { - mlir::Value func_arg = func.getArgument(i); - mlir::Value new_kernel_arg = new_kernel.getArgument(i); - mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg); - if (!old_kernel_arg) { - kernel.emitOpError() - << "argument " << i - << " to containing function is not an argument to the kernel"; - signalPassFailure(); - return; - } - old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); - } - for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) { - mlir::Value ret_op = returnOp->getOperand(i); - mlir::Value new_kernel_arg = - new_kernel.getArgument(func.getNumArguments() + i); - mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op); - if (!old_kernel_arg) { - kernel.emitOpError() - << "result " << i - << " of containing function is not an argument to the kernel"; - signalPassFailure(); - return; - } - old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); - } - // Steal the body by appending the blocks and inserting a branch. - kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new); - kernel_builder.setInsertionPointToEnd(&new_kernel.body().front()); - kernel_builder.create( - new_kernel.getLoc(), &*std::next(new_kernel.body().begin())); - // Now create a new launchOp calling the new kernel. We need to forward - // the arguments of the surrounding function and operands to the return. - mlir::SmallVector new_operands; - new_operands.reserve(new_kernel.getNumFuncArguments()); - new_operands.append(func.args_begin(), func.args_end()); - new_operands.append(returnOp->operand_begin(), returnOp->operand_end()); - mlir::OpBuilder launch_builder(launchOp); - launch_builder.create( - launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(), - launchOp.getBlockSizeOperandValues(), new_operands); - // Launch does not have results, so we can just erase it. And the kernel - // also needs to go. - launchOp.erase(); - kernel.erase(); - }); - } -}; - -// Extract_element(mhlo_scalars_to_dimension_tensor(v_i), i) -> v_i -// -// We need to direct fusion to the inner loops. This cannot be done with -// a passmanager alone ATM, as nested pass managers require operations to -// be closed from above. -struct MapParallelLoops - : public mlir::PassWrapper { - void runOnFunction() override { - mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); - } -}; - -// We need to direct fusion to the inner loops. 
This cannot be done with -// a passmanager alone ATM, as nested pass managers require operations to -// be closed from above. -struct FuseInnerParallelLoops - : public mlir::PassWrapper { - void runOnFunction() override { - getFunction().walk([](mlir::scf::ParallelOp op) { - mlir::scf::naivelyFuseParallelOps(op.region()); - }); - } -}; - -// Collapse all loop dimension into the first one. -struct ParallelLoopCollapsingToFirstDim - : public mlir::PassWrapper> { - void runOnOperation() override { - mlir::Operation* module = getOperation(); - - module->walk([&](mlir::scf::ParallelOp op) { - unsigned num_loops = op.getNumLoops(); - std::vector combinedLoops; - combinedLoops.reserve(num_loops); - for (unsigned i = 0; i < num_loops; ++i) { - combinedLoops.push_back(i); - } - mlir::collapseParallelLoops(op, {combinedLoops}); - }); - } -}; -} // namespace Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { mlir::PassManager pm(module.getContext()); @@ -461,14 +68,14 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Moving `AllocOp`s and inserting missing `DeallocOp`s pm.addPass(::mlir::createBufferPlacementPass()); // Next, we can strip the outer fusion operation. - pm.addPass(absl::make_unique()); + pm.addPass(createFusionOpRemoverPass()); // Remove unnecessary LHLO copies. pm.addPass(::mlir::lmhlo::createLhloCopyRemovalPass()); // Transform LHLO operations to LinAlg. pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. - pm.addPass(::mlir::lmhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true, - tiling_for_unrolling)); + pm.addPass(::mlir::lmhlo::createLhloFuseLinalgPass( + /*use_parallel_loops=*/true, tiling_for_unrolling)); // Legalize reduce operations directly to GPU dialect. pm.addPass(::mlir::lmhlo::createLegalizeToGpuPass()); // Transform the Linalg operations inside of the loop nest into parallel @@ -479,26 +86,26 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Fuse the inner-most loops. - pm.addPass(absl::make_unique()); + pm.addPass(createFuseInnerParallelLoopsPass()); // Run CSE to ensure that loads and stores to the same subview get // recognized as such. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Forward stores to buffers to loads. - pm.addPass(absl::make_unique()); + pm.addPass(createStoreForwardingPass()); // Remove now unused temporary buffers. - pm.addPass(absl::make_unique()); + pm.addPass(createDeadTempBufferRemovalPass()); if (!options.unroll_factors.empty()) { pm.addPass(::mlir::createParallelLoopTilingPass(as_int64)); } // Project all loop dimensions to X if necessary. if (options.collapse_parallel_loops) { - pm.addPass(absl::make_unique()); + pm.addPass(createParallelLoopCollapsingToFirstDimPass()); } // Some basic cleanup. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); // Greedily map the remaining loop to GPU hardware dimensions. - pm.addPass(absl::make_unique()); + pm.addPass(createMapParallelLoopsPass()); // Apply the mapping. pm.addPass(mlir::createParallelLoopToGpuPass()); // Some basic cleanup. @@ -512,16 +119,16 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Approximate of requested. 
if (options.use_approximations) { pm.addNestedPass<::mlir::FuncOp>( - ::mlir::hlo::createLegalizeTanhToApproximationPass()); + ::mlir::mhlo::createLegalizeTanhToApproximationPass()); } // Move scalar operations into the launch to ensure smaller signatures. - pm.addPass(absl::make_unique()); + pm.addPass(createMoveScalarComputationsIntoGpuLaunchPass()); // Take launches to launches with kernels. pm.addPass(::mlir::createGpuKernelOutliningPass()); // Make sure the kernel signature resembled the original function's // signature if (options.rewrite_signature) { - pm.addPass(absl::make_unique()); + pm.addPass(createRewriteKernelSignaturePass()); } if (failed(pm.run(module))) { return InternalError("Lowering to GPU kernels failed."); @@ -595,5 +202,6 @@ StatusOr ExtractKernelModule(mlir::ModuleOp module) { }); return kernelModule; } + } // namespace mlir_gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 194eb4618d3..e0d7456fbb8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -205,7 +205,7 @@ LhloDialectEmitter::LhloDialectEmitter( platform_(platform) { LLVMDialect* llvmDialect = mlir_module.getContext()->getRegisteredDialect(); - pointer_size_ = llvmDialect->getLLVMModule().getDataLayout().getPointerSize(); + pointer_size_ = llvmDialect->getDataLayout().getPointerSize(); } void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr thunk) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index 458522f89e6..df2bd2e4c23 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -30,18 +30,14 @@ namespace { using ::mlir::MLIRContext; using ::mlir::LLVM::LLVMDialect; -int64 ConfigureLLVMModuleAndGetPointerSize(MLIRContext* context) { +int64 GetPointerSize(MLIRContext* context) { LLVMDialect* dialect = context->getRegisteredDialect(); - llvm::Module& module = dialect->getLLVMModule(); - module.setTargetTriple(gpu::nvptx::kTargetTriple); - module.setDataLayout(gpu::nvptx::kDataLayout); - return module.getDataLayout().getPointerSize(); + return dialect->getDataLayout().getPointerSize(); } } // namespace -MlirCompiler::MlirCompiler() - : pointer_size_(ConfigureLLVMModuleAndGetPointerSize(&context_)) {} +MlirCompiler::MlirCompiler() : pointer_size_(GetPointerSize(&context_)) {} se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 2c2076bbd97..4879c6b5099 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "llvm/IR/LLVMContext.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project @@ -292,10 +293,10 @@ Status InsertBufferLoadPreduleIntoKernel( BufferAssignment* assignment, const std::vector& buffers) { mlir::OpBuilder builder(kernel.getBody()); - auto llvm_dialect = kernel.getContext()->getRegisteredDialect(); - auto offset_type = LLVMType::getInt64Ty(llvm_dialect); - auto ptr_type = LLVMType::getInt8PtrTy(llvm_dialect); - auto void_type = LLVMType::getVoidTy(llvm_dialect); + auto* context = kernel.getContext(); + auto offset_type = LLVMType::getInt64Ty(context); + auto ptr_type = LLVMType::getInt8PtrTy(context); + auto void_type = LLVMType::getVoidTy(context); auto loc = kernel.getLoc(); auto num_original_args = kernel.getNumArguments(); @@ -543,7 +544,11 @@ StatusOr> MlirCompilerImpl::RunBackend( TF_RETURN_IF_ERROR( module_hook_.invoke(IRHook::LoweringStage::KERNEL, *kernel_module)); - auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module); + // Translate to LLVM IR in a fresh context. The module is further translated + // to textual PTX and a CUBIN blob so there is no need for the context to live + // longer than this function. + llvm::LLVMContext llvmContext; + auto llvmModule = mlir::translateModuleToNVVMIR(*kernel_module, llvmContext); if (!llvmModule) { return InternalError("Translation to LLVM failed"); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.cc b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc new file mode 100644 index 00000000000..887f14e90d9 --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.cc @@ -0,0 +1,423 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/mlir_gpu/passes.h" + +#include "absl/memory/memory.h" +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/GPU/GPUDialect.h" // from @llvm-project +#include "mlir/Dialect/GPU/ParallelLoopMapper.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +namespace xla { +namespace mlir_gpu { +namespace { + +struct FusionOpRemoverPass + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([&](mlir::lmhlo::FusionOp op) { + mlir::OpBuilder builder(op); + // FusionOp has a single region with a single block, so we can just walk + // over it and clone operations to the outside. + mlir::BlockAndValueMapping mapping; + for (auto& nested_op : op.region().front().without_terminator()) { + auto clone = builder.clone(nested_op, mapping); + for (auto pair : + llvm::zip(nested_op.getResults(), clone->getResults())) { + mapping.map(std::get<0>(pair), std::get<1>(pair)); + } + } + op.erase(); + }); + } +}; + +struct StoreForwardingPass + : mlir::PassWrapper { + mlir::StoreOp findStore(mlir::Operation* op, + std::function matches) { + // Search from op upwards in the current block. + mlir::Block* block = op->getBlock(); + auto startFromIt = + std::find_if(block->rbegin(), block->rend(), + [op](mlir::Operation& other) { return &other == op; }); + for (auto storeOpIt = startFromIt; storeOpIt != block->rend(); + ++storeOpIt) { + auto storeOp = llvm::dyn_cast(&*(storeOpIt)); + if (!storeOp || !matches(storeOp)) { + continue; + } + + return storeOp; + } + // No store operation found. Continue search outside of the parallel + // loop if block is in a parallel loop. + if (auto parallelOp = + llvm::dyn_cast(block->getParentOp())) { + return findStore(parallelOp.getOperation(), matches); + } + return {}; + } + + // Recursively search defining ops for AllocOp. Return either AllocOp if it is + // found or nullptr. + mlir::Operation* SearchAllocOp(mlir::Value memref) { + mlir::Operation* defOp = memref.getDefiningOp(); + while (auto subviewOp = mlir::dyn_cast_or_null(defOp)) { + defOp = subviewOp.source().getDefiningOp(); + } + if (auto allocOp = mlir::dyn_cast_or_null(defOp)) { + return allocOp.getOperation(); + } + return nullptr; + } + + // Retrieves AllocOp from the cache or actually looks for it. 
+ mlir::Operation* GetAllocOp( + mlir::Value memref, + llvm::DenseMap* memrefToAllocOp) { + auto allocOpIt = memrefToAllocOp->find(memref); + if (allocOpIt != memrefToAllocOp->end()) { + return allocOpIt->second; + } + auto allocOp = SearchAllocOp(memref); + memrefToAllocOp->insert({memref, allocOp}); + return allocOp; + } + + void runOnFunction() override { + llvm::DenseMap memrefToAllocOp; + + getFunction().walk([&](mlir::LoadOp loadOp) { + auto storeOp = findStore(loadOp, [&](mlir::StoreOp storeOp) { + mlir::Operation* storeOpAlloc = + GetAllocOp(storeOp.memref(), &memrefToAllocOp); + mlir::Operation* loadOpAlloc = + GetAllocOp(loadOp.memref(), &memrefToAllocOp); + return storeOpAlloc && loadOpAlloc && (storeOpAlloc == loadOpAlloc); + }); + if (!storeOp) { + return; + } + auto storeIndices = storeOp.getIndices(); + auto loadIndices = loadOp.getIndices(); + if (!std::equal(storeIndices.begin(), storeIndices.end(), + loadIndices.begin(), loadIndices.end())) { + return; + } + loadOp.replaceAllUsesWith(storeOp.getValueToStore()); + loadOp.erase(); + }); + } +}; + +struct DeadTempBufferRemovalPass + : mlir::PassWrapper { + bool operationConsideredDead(mlir::Operation* op) { + for (auto result : op->getResults()) { + if (!llvm::all_of(result.getUsers(), [&](mlir::Operation* op) { + // Store and Dealloc is OK. + if (llvm::isa(op)) { + return true; + } + // Load without uses is also ok. + if (auto loadOp = llvm::dyn_cast(op)) { + return loadOp.use_empty(); + } + // Subview is ok if it is dead itself. + if (llvm::isa(op)) { + return operationConsideredDead(op); + } + return false; + })) { + return false; + } + } + return true; + } + + void recursiveErase(mlir::Operation* op, + llvm::SmallVectorImpl* erase_list) { + for (auto result : op->getResults()) { + for (auto user : llvm::make_early_inc_range(result.getUsers())) { + recursiveErase(user, erase_list); + } + } + erase_list->push_back(op); + } + + void runOnFunction() override { + llvm::SmallVector dead_ops; + getFunction().walk([&](mlir::AllocOp allocOp) { + if (!operationConsideredDead(allocOp)) { + return; + } + + // TODO(herhut): There should be a generic helper for this. + recursiveErase(allocOp, &dead_ops); + }); + for (auto op : dead_ops) { + op->erase(); + } + } +}; + +struct MoveScalarComputationsIntoGpuLaunchPass + : mlir::PassWrapper { + static bool isInliningBeneficiary(mlir::Operation* op) { + return llvm::isa(op); + } + + static bool extractBeneficiaryOps( + mlir::Operation* op, llvm::SmallVectorImpl* ops, + llvm::SetVector args) { + if (!isInliningBeneficiary(op)) { + return false; + } + + ops->push_back(op); + for (auto operand : op->getOperands()) { + // It is an existing arg, keep going. 
+ if (args.count(operand)) { + continue; + } + mlir::Operation* definingOp = operand.getDefiningOp(); + if (!definingOp || !extractBeneficiaryOps(definingOp, ops, args)) { + return false; + } + } + return true; + } + + static void inlineOperationsIntoLaunch(mlir::gpu::LaunchOp launch) { + llvm::SetVector used_above; + mlir::getUsedValuesDefinedAbove(launch.body(), used_above); + mlir::BlockAndValueMapping inlined_map; + for (mlir::Value v : used_above) { + llvm::SmallVector ops_to_move; + mlir::Operation* definingOp = v.getDefiningOp(); + if (definingOp && + extractBeneficiaryOps(definingOp, &ops_to_move, used_above)) { + mlir::OpBuilder b(launch.body()); + for (mlir::Operation* op : llvm::reverse(ops_to_move)) { + auto result = b.clone(*op, inlined_map); + for (auto pair : llvm::zip(op->getResults(), result->getResults())) { + mlir::replaceAllUsesInRegionWith(std::get<0>(pair), + std::get<1>(pair), launch.body()); + } + inlined_map.map(op->getResults(), result->getResults()); + } + } + } + } + + void runOnFunction() override { + mlir::FuncOp fun = getFunction(); + fun.walk( + [](mlir::gpu::LaunchOp launch) { inlineOperationsIntoLaunch(launch); }); + } +}; + +struct RewriteKernelSignaturePass + : mlir::PassWrapper { + void runOnFunction() override { + mlir::FuncOp func = getFunction(); + mlir::ModuleOp module = func.getParentOfType(); + getFunction().walk([&](mlir::gpu::LaunchFuncOp launchOp) { + mlir::gpu::GPUFuncOp kernel = + module.lookupSymbol(launchOp.kernel()); + + if (kernel.getNumFuncArguments() != + func.getNumArguments() + func.getNumResults()) { + kernel.emitError() + << "number of kernel arguments does not match number" + << "of arguments and results of surrounding function"; + signalPassFailure(); + return; + } + if (!llvm::hasSingleElement(func)) { + func.emitError() << "surrounding function has more than one block"; + signalPassFailure(); + return; + } + + // Compute a map from function arguments to kernel function operands. + mlir::BlockAndValueMapping func_to_kernel; + for (mlir::BlockArgument arg : func.getArguments()) { + for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { + if (launchOp.getKernelOperand(i) == arg) { + func_to_kernel.map(arg, kernel.getArgument(i)); + break; + } + } + } + // Also add function results that are computed by the launch. + mlir::Operation* returnOp = func.getBody().back().getTerminator(); + for (mlir::Value result : returnOp->getOperands()) { + for (int i = 0, e = launchOp.getNumKernelOperands(); i < e; ++i) { + if (launchOp.getKernelOperand(i) == result) { + func_to_kernel.map(result, kernel.getArgument(i)); + break; + } + } + } + + // Create a new kernel function with modified signature. It will have the + // parameters and result types of the original funcion as its parameter + // type and otherwise will be void. + auto gpu_module = kernel.getParentOfType(); + mlir::OpBuilder kernel_builder(gpu_module.body()); + auto operand_types = llvm::to_vector<4>(llvm::concat( + func.getType().getInputs(), func.getType().getResults())); + auto new_kernel = kernel_builder.create( + kernel.getLoc(), kernel.getName(), + kernel_builder.getFunctionType(operand_types, {})); + new_kernel.setAttr(mlir::gpu::GPUDialect::getKernelFuncAttrName(), + kernel_builder.getUnitAttr()); + + // Create a map from old kernel argument to new one. 
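// A small worked example of the reordering this pass performs (the function
// and value names below are made up for illustration): for a surrounding
// function
//
//   func @main(%arg0: memref<4xf32>, %arg1: memref<4xf32>) -> memref<4xf32>
//
// whose gpu.launch_func passes the kernel operands in whatever order buffer
// assignment produced, the rewritten kernel takes its operands in the order
// (%arg0, %arg1, <returned buffer>): function arguments first, then the
// values returned by the function, as described in the pass documentation.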
+ mlir::BlockAndValueMapping old_kernel_to_new; + for (int i = 0, e = func.getNumArguments(); i < e; ++i) { + mlir::Value func_arg = func.getArgument(i); + mlir::Value new_kernel_arg = new_kernel.getArgument(i); + mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(func_arg); + if (!old_kernel_arg) { + kernel.emitOpError() + << "argument " << i + << " to containing function is not an argument to the kernel"; + signalPassFailure(); + return; + } + old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); + } + for (int i = 0, e = returnOp->getNumOperands(); i < e; ++i) { + mlir::Value ret_op = returnOp->getOperand(i); + mlir::Value new_kernel_arg = + new_kernel.getArgument(func.getNumArguments() + i); + mlir::Value old_kernel_arg = func_to_kernel.lookupOrNull(ret_op); + if (!old_kernel_arg) { + kernel.emitOpError() + << "result " << i + << " of containing function is not an argument to the kernel"; + signalPassFailure(); + return; + } + old_kernel_to_new.map(old_kernel_arg, new_kernel_arg); + } + // Steal the body by appending the blocks and inserting a branch. + kernel.body().cloneInto(&new_kernel.getBody(), old_kernel_to_new); + kernel_builder.setInsertionPointToEnd(&new_kernel.body().front()); + kernel_builder.create( + new_kernel.getLoc(), &*std::next(new_kernel.body().begin())); + // Now create a new launchOp calling the new kernel. We need to forward + // the arguments of the surrounding function and operands to the return. + mlir::SmallVector new_operands; + new_operands.reserve(new_kernel.getNumFuncArguments()); + new_operands.append(func.args_begin(), func.args_end()); + new_operands.append(returnOp->operand_begin(), returnOp->operand_end()); + mlir::OpBuilder launch_builder(launchOp); + launch_builder.create( + launchOp.getLoc(), new_kernel, launchOp.getGridSizeOperandValues(), + launchOp.getBlockSizeOperandValues(), new_operands); + // Launch does not have results, so we can just erase it. And the kernel + // also needs to go. 
+ launchOp.erase(); + kernel.erase(); + }); + } +}; + +struct MapParallelLoopsPass + : public mlir::PassWrapper { + void runOnFunction() override { + mlir::greedilyMapParallelSCFToGPU(getFunction().getBody()); + } +}; + +struct FuseInnerParallelLoopsPass + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](mlir::scf::ParallelOp op) { + mlir::scf::naivelyFuseParallelOps(op.region()); + }); + } +}; + +struct ParallelLoopCollapsingToFirstDimPass + : public mlir::PassWrapper> { + void runOnOperation() override { + mlir::Operation* module = getOperation(); + + module->walk([&](mlir::scf::ParallelOp op) { + unsigned num_loops = op.getNumLoops(); + std::vector combinedLoops; + combinedLoops.reserve(num_loops); + for (unsigned i = 0; i < num_loops; ++i) { + combinedLoops.push_back(i); + } + mlir::collapseParallelLoops(op, {combinedLoops}); + }); + } +}; + +} // namespace + +std::unique_ptr createFusionOpRemoverPass() { + return absl::make_unique(); +} + +std::unique_ptr createStoreForwardingPass() { + return absl::make_unique(); +} + +std::unique_ptr createDeadTempBufferRemovalPass() { + return absl::make_unique(); +} + +std::unique_ptr +createMoveScalarComputationsIntoGpuLaunchPass() { + return absl::make_unique(); +} + +std::unique_ptr createRewriteKernelSignaturePass() { + return absl::make_unique(); +} + +std::unique_ptr createFuseInnerParallelLoopsPass() { + return absl::make_unique(); +} + +std::unique_ptr createMapParallelLoopsPass() { + return absl::make_unique(); +} + +std::unique_ptr> +createParallelLoopCollapsingToFirstDimPass() { + return absl::make_unique(); +} + +} // namespace mlir_gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/passes.h b/tensorflow/compiler/xla/service/mlir_gpu/passes.h new file mode 100644 index 00000000000..e3840628a2e --- /dev/null +++ b/tensorflow/compiler/xla/service/mlir_gpu/passes.h @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace xla { +namespace mlir_gpu { + +// TODO(herhut, pifon): Move these passes to MLIR Core. + +/// Replaces a FusionOp by the operations contained in its region. +std::unique_ptr createFusionOpRemoverPass(); + +/// Replaces a load that immediately follows a store to the same address with +/// the stored value. This needs generalization. +std::unique_ptr createStoreForwardingPass(); + +/// Removes temporary buffers that are only written to but never read from or +/// that are read but the read value is not used. Needs an analysis that proves +/// that loads and stores are side-effect free (in bounds, no aliasing, etc.). +std::unique_ptr createDeadTempBufferRemovalPass(); + +/// Moves scalar computations to the GPULaunchOp body. 
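// A minimal usage sketch for the factory functions declared in this header,
// mirroring how LowerLHLOToGPU composes them above. The helper name below is
// hypothetical and error handling is omitted.
//
//   #include "mlir/Pass/PassManager.h"  // from @llvm-project
//   #include "tensorflow/compiler/xla/service/mlir_gpu/passes.h"
//
//   void AddMlirGpuCleanupPasses(mlir::PassManager& pm) {
//     pm.addPass(xla::mlir_gpu::createFusionOpRemoverPass());
//     pm.addPass(xla::mlir_gpu::createStoreForwardingPass());
//     pm.addPass(xla::mlir_gpu::createDeadTempBufferRemovalPass());
//     pm.addPass(xla::mlir_gpu::createMoveScalarComputationsIntoGpuLaunchPass());
//   }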
+std::unique_ptr +createMoveScalarComputationsIntoGpuLaunchPass(); + +/// Sorts the operands to the kernel for a deterministic order. First operands +/// that are defined by function arguments, followed by operands that are +/// returned from the function. This only works for simple functions without +/// control flow and can be used in cases where the kernel is extracted and used +/// independently of the host-side code. +std::unique_ptr createRewriteKernelSignaturePass(); + +/// We need to direct fusion to the inner loops. This cannot be done with +/// a passmanager alone ATM, as nested pass managers require operations to +/// be closed from above. +std::unique_ptr createFuseInnerParallelLoopsPass(); + +/// Greedily maps loops to GPU hardware dimensions. +std::unique_ptr createMapParallelLoopsPass(); + +/// Collapses all loop dimension into the first one. +std::unique_ptr> +createParallelLoopCollapsingToFirstDimPass(); + +} // namespace mlir_gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_PASSES_H_ diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo index 953eb2022f8..8d7930ea8c0 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_as_kernel.hlo @@ -7,24 +7,24 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { ROOT %add = f32[2,2]{1,0} add(f32[2,2]{1,0} %x, f32[2,2]{1,0} %y) } -// CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm<.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] +// CHECK: func @add_kernel(%[[ARG0:.*]]: [[TYPE:!llvm\..*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]] // // Check that relevant sizes and strides are emitted. // -// CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[CAST0:.*]] = llvm.bitcast %[[ARG0:.*]] : !llvm.ptr to !llvm.ptr // CHECK: %[[SIZE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[SIZE01:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[STRIDE01:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: %[[STRIDE00:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG1:.*]] : !llvm.ptr to !llvm.ptr // CHECK: %[[SIZE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[SIZE11:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[STRIDE11:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 // CHECK: %[[STRIDE10:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 -// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm<"i8*"> to !llvm<"float*"> +// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[ARG2:.*]] : !llvm.ptr to !llvm.ptr // CHECK: %[[SIZE20:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[SIZE21:.*]] = llvm.mlir.constant(2 : i64) : !llvm.i64 // CHECK: %[[STRIDE21:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i64 @@ -34,30 +34,30 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { // Check that the emitted sizes and strides, as well the pointers to HLO buffers, // are inserted into the memref descriptors. 
// -// CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC01:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC0]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC02:.*]] = llvm.insertvalue %[[CAST0]], %[[DESC01]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC03:.*]] = llvm.insertvalue %{{.*}}, %[[DESC02]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC04:.*]] = llvm.insertvalue %[[SIZE00]], %[[DESC03]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC05:.*]] = llvm.insertvalue %[[STRIDE00]], %[[DESC04]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC06:.*]] = llvm.insertvalue %[[SIZE01]], %[[DESC05]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE01]], %[[DESC06]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC11:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC1]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC12:.*]] = llvm.insertvalue %[[CAST1]], %[[DESC11]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC13:.*]] = llvm.insertvalue %{{.*}}, %[[DESC12]][2] : 
!llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC14:.*]] = llvm.insertvalue %[[SIZE10]], %[[DESC13]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC15:.*]] = llvm.insertvalue %[[STRIDE10]], %[[DESC14]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC16:.*]] = llvm.insertvalue %[[SIZE11]], %[[DESC15]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE11]], %[[DESC16]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> -// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> +// CHECK: %[[DESC2:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC21:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC2]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC22:.*]] = llvm.insertvalue %[[CAST2]], %[[DESC21]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC23:.*]] = llvm.insertvalue %{{.*}}, %[[DESC22]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC24:.*]] = llvm.insertvalue %[[SIZE20]], %[[DESC23]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC25:.*]] = llvm.insertvalue %[[STRIDE20]], %[[DESC24]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[DESC26:.*]] = llvm.insertvalue %[[SIZE21]], %[[DESC25]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %{{.*}} = llvm.insertvalue %[[STRIDE21]], %[[DESC26]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo index 3a3dd22b338..8656b4edeb7 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo @@ -6,7 +6,7 @@ ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} { ROOT %copy = f32[2,4]{0,1} copy(f32[2,4] %x) } -// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> +// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1 * 2)> // CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, // CHECK-SAME: %[[RESULT:.*]]: memref<2x4xf32, #[[MAP0]]>) // CHECK: "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 
2ed5e709d81..bc79f16db2a 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -270,12 +270,13 @@ StatusOr> Service::CreateModuleConfig( auto config = absl::make_unique(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); - if (program_shape.parameters_size() != argument_shapes.size()) { + const int64 argument_shapes_size = argument_shapes.size(); + if (program_shape.parameters_size() != argument_shapes_size) { return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } - for (int i = 0; i < argument_shapes.size(); ++i) { + for (int i = 0, end = argument_shapes.size(); i < end; ++i) { // Verify that shape of arguments matches the shape of the arguments in the // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], @@ -315,6 +316,7 @@ StatusOr> Service::CreateModuleConfig( } config->set_use_spmd_partitioning( execution_options->use_spmd_partitioning()); + config->set_deduplicate_hlo(execution_options->deduplicate_hlo()); config->set_seed(execution_options->seed()); config->set_launch_id(execution_options->launch_id()); config->set_debug_options(execution_options->debug_options()); @@ -371,7 +373,7 @@ StatusOr>> Service::BuildExecutables( // Dump computation proto state if flag is set. std::vector> hlo_protos; - for (int64 i = 0; i < module_protos.size(); ++i) { + for (int64 i = 0, end = module_protos.size(); i < end; ++i) { auto hlo_proto = absl::make_unique(); *hlo_proto->mutable_hlo_module() = *module_protos[i]; hlo_protos.push_back(std::move(hlo_proto)); @@ -385,7 +387,7 @@ StatusOr>> Service::BuildExecutables( CHECK_EQ(module_protos.size(), module_configs.size()); auto module_group = absl::make_unique(module_protos[0]->name()); - for (int64 i = 0; i < module_protos.size(); ++i) { + for (int64 i = 0, end = module_protos.size(); i < end; ++i) { const HloModuleProto* proto = module_protos[i]; const HloModuleConfig& config = *module_configs[i]; TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config)); @@ -433,12 +435,12 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < executables.size(); i++) { TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); CHECK_EQ(replicas.size(), arguments[i].size()); - for (int64 replica = 0; replica < replicas.size(); ++replica) { + for (int64 replica = 0, end = replicas.size(); replica < end; ++replica) { device_assignment(replica, i) = replicas[replica]->device_ordinal(); } } - for (int64 i = 0; i < executables.size(); i++) { + for (int64 i = 0, end = executables.size(); i < end; i++) { // Stream executors for the replicas of the current computation. TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); CHECK_EQ(replicas.size(), arguments[i].size()); @@ -497,7 +499,7 @@ Service::ExecuteParallelAndRegisterResult( } // Wait for all executions to complete. 
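// The loop rewrites in this file all follow the same pattern: hoist the
// container size into a signed local so the loop index (int64) is compared
// against another int64 rather than against size_t. The motivation is
// inferred, not stated in the patch; a standalone sketch with standard types:
//
//   #include <cstdint>
//   #include <vector>
//
//   void Visit(const std::vector<int>& items) {
//     for (int64_t i = 0, end = items.size(); i < end; ++i) {
//       // ... use items[i] ...
//     }
//   }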
- for (int64 i = 0; i < streams.size(); ++i) { + for (int64 i = 0, end = streams.size(); i < end; ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("failed to complete execution for stream %d: %s", i, @@ -715,7 +717,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, std::vector snapshots; snapshots.resize(executable_ptrs.size()); - for (int i = 0; i < executable_ptrs.size(); i++) { + for (int i = 0, end = executable_ptrs.size(); i < end; i++) { if (executable_ptrs[i]->dumping_snapshot()) { *snapshots[i].mutable_hlo() = *executable_ptrs[i]->hlo_proto(); TF_ASSIGN_OR_RETURN(auto stream, @@ -761,7 +763,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, *result->add_responses() = response; } - for (int i = 0; i < executable_ptrs.size(); i++) { + for (int i = 0, end = executable_ptrs.size(); i < end; i++) { Executable* executable = executable_ptrs[i]; if (executable->dumping_snapshot()) { TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index ec8e4d23d21..8e39e32e4c3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -643,11 +643,6 @@ Status ValidateDotDimensionNumbers( return InvalidArgument("%s", message); }; - // Check if both element types are the same. - if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return fail("Element types do not match."); - } - // Validate basic properties of dot dimension numbers. TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); @@ -954,18 +949,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray( rhs, absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { + case HloOpcode::kAdd: case HloOpcode::kMaximum: case HloOpcode::kMinimum: + case HloOpcode::kMultiply: return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); case HloOpcode::kSubtract: - case HloOpcode::kAdd: case HloOpcode::kAtan2: case HloOpcode::kPower: case HloOpcode::kDivide: case HloOpcode::kRemainder: - case HloOpcode::kMultiply: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -1621,11 +1616,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, batch_group_count, feature_group_count); } - if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return InvalidArgument( - "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); - } if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 6c4cf2d7866..4ff492047a3 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -91,9 +91,7 @@ bool IsShardingMoreSpecific(const HloSharding& lhs, const HloSharding& rhs) { return is_better; } if (!rhs.IsTileMaximal()) { - // If we already have a non-tile-maximal sharding then we can't improve - // that. 
- return false; + return lhs.NumTiles() > rhs.NumTiles(); } else if (!rhs.IsReplicated()) { // If we are not replicated then only tiled (not tile maximal) shardings // can improve us. @@ -124,9 +122,12 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a, // Updates the sharding of the specified instruction with the specified sharding // if it is better than the current one and returns true if a new sharding have -// been applied. +// been applied. If may_combine_partial_sharding is true, this may combine the +// new and existing sharding if they are both partial tiling partial +// replication. bool MaybeImproveInstructionSharding(const HloSharding& sharding, - HloInstruction* instruction) { + HloInstruction* instruction, + bool may_combine_partial_sharding) { // We don't want to propagate tile maximal shardings. if (!IsSpatiallyPartitioned(sharding)) { return false; @@ -136,6 +137,101 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding, instruction->set_sharding(sharding); return true; } + if (may_combine_partial_sharding && sharding.ReplicateOnLastTileDim() && + instruction->sharding().ReplicateOnLastTileDim()) { + if (sharding.tile_assignment().num_elements() == + instruction->sharding().tile_assignment().num_elements()) { + // Combine the tile dimension sizes from new and old. + int64 num_devices = sharding.tile_assignment().num_elements(); + std::vector new_tile_dims; + bool compatible = true; + new_tile_dims.reserve(sharding.tile_assignment().num_dimensions()); + for (int64 i = 0; i < sharding.tile_assignment().num_dimensions() - 1; + ++i) { + int64 new_dim = sharding.tile_assignment().dim(i); + int64 old_dim = instruction->sharding().tile_assignment().dim(i); + if (new_dim == 1) { + new_tile_dims.push_back(old_dim); + } else if (old_dim == 1) { + new_tile_dims.push_back(new_dim); + } else if (new_dim == old_dim) { + new_tile_dims.push_back(new_dim); + } else { + compatible = false; + break; + } + } + int64 replication = num_devices / Product(new_tile_dims); + if (compatible && num_devices % Product(new_tile_dims) == 0 && + replication < + instruction->sharding().tile_assignment().dimensions().back()) { + new_tile_dims.push_back(replication); + Array new_tile(new_tile_dims); + // Maps from replication group ID to sorted members. + absl::flat_hash_map> old_group_members; + absl::flat_hash_map> new_group_members; + auto get_group_index = [&](absl::Span tile_indices, + const HloSharding& sharding) { + int64 group_id = 0; + for (int64 i = 0; i < tile_indices.size() - 1; ++i) { + group_id *= sharding.tile_assignment().dim(i); + group_id += tile_indices[i]; + } + return group_id; + }; + instruction->sharding().tile_assignment().Each( + [&](absl::Span indices, int64 device) { + old_group_members[get_group_index(indices, + instruction->sharding())] + .insert(device); + }); + sharding.tile_assignment().Each([&](absl::Span indices, + int64 device) { + new_group_members[get_group_index(indices, sharding)].insert(device); + }); + // Try to find the intersection of old and new replication groups, in + // order to determine the merged tile assignment. 
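// The per-dimension merge above follows a simple rule: a dimension of the new
// and old tile assignments can be combined when either side is 1 or both are
// equal; otherwise the two shardings are treated as incompatible. A
// self-contained sketch of just that rule (names are illustrative):
//
//   #include <cstdint>
//   #include <optional>
//   #include <vector>
//
//   std::optional<std::vector<int64_t>> MergeTileDims(
//       const std::vector<int64_t>& new_dims,
//       const std::vector<int64_t>& old_dims) {
//     std::vector<int64_t> merged;
//     merged.reserve(new_dims.size());
//     for (size_t i = 0; i < new_dims.size(); ++i) {
//       if (new_dims[i] == 1) {
//         merged.push_back(old_dims[i]);
//       } else if (old_dims[i] == 1 || new_dims[i] == old_dims[i]) {
//         merged.push_back(new_dims[i]);
//       } else {
//         return std::nullopt;  // incompatible tilings
//       }
//     }
//     return merged;
//   }
//
// For example, over 8 devices, merging new dims {2, 1} with old dims {1, 4}
// yields {2, 4}, i.e. two partially replicated shardings combine into a full
// 2x4 tiling, provided the replication groups agree on a common device
// assignment, which is what the group-intersection logic below verifies.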
+ new_tile.Each([&](absl::Span indices, int64* device) { + if (!compatible) { + return; + } + std::vector old_index(indices.begin(), indices.end()); + std::vector new_index = old_index; + for (int64 i = 0; i < indices.size() - 1; ++i) { + if (instruction->sharding().tile_assignment().dim(i) == 1) { + old_index[i] = 0; + } + if (sharding.tile_assignment().dim(i) == 1) { + new_index[i] = 0; + } + } + int64 old_group_id = + get_group_index(old_index, instruction->sharding()); + int64 new_group_id = get_group_index(new_index, sharding); + if (old_group_members[old_group_id].empty() || + new_group_members[new_group_id].empty() || + *old_group_members[old_group_id].begin() != + *new_group_members[new_group_id].begin()) { + compatible = false; + return; + } + *device = *old_group_members[old_group_id].begin(); + old_group_members[old_group_id].erase(*device); + new_group_members[new_group_id].erase(*device); + }); + if (compatible) { + if (replication == 1) { + new_tile_dims.pop_back(); + new_tile.Reshape(new_tile_dims); + instruction->set_sharding(HloSharding::Tile(new_tile)); + } else { + instruction->set_sharding(HloSharding::PartialTile(new_tile)); + } + return true; + } + } + } + } if (IsShardingMoreSpecific(sharding, instruction->sharding())) { instruction->set_sharding(sharding); return true; @@ -363,7 +459,8 @@ bool SupportSpatialPartitioning(const HloInstruction* instruction, // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - bool aggressive_prop) { + bool aggressive_prop, + bool may_combine_partial_sharding) { const auto& dnums = instruction->convolution_dimension_numbers(); const HloInstruction* lhs = instruction->operand(0); const HloInstruction* rhs = instruction->operand(1); @@ -430,13 +527,15 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, partitioned_only_along_non_trivial_dims(lhs->sharding(), dot_dims->batch_dims, 0)) { return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(), - instruction); + instruction, + may_combine_partial_sharding); } if (IsSpatiallyPartitioned(rhs) && partitioned_only_along_non_trivial_dims(rhs->sharding(), dot_dims->batch_dims, 1)) { return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_rhs(), - instruction); + instruction, + may_combine_partial_sharding); } if (aggressive_prop) { // If LHS/RHS is partitioned only along the non-contracting @@ -455,19 +554,23 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, if (Product(lhs->shape().dimensions()) >= Product(rhs->shape().dimensions())) { return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_lhs(), instruction); + get_tiled_sharding_based_on_lhs(), instruction, + may_combine_partial_sharding); } else { return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_rhs(), instruction); + get_tiled_sharding_based_on_rhs(), instruction, + may_combine_partial_sharding); } } if (can_propagate_from_lhs) { return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_lhs(), instruction); + get_tiled_sharding_based_on_lhs(), instruction, + may_combine_partial_sharding); } if (can_propagate_from_rhs) { return MaybeImproveInstructionSharding( - get_tiled_sharding_based_on_rhs(), instruction); + get_tiled_sharding_based_on_rhs(), instruction, + may_combine_partial_sharding); } } } @@ -476,8 +579,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, return false; } if (lhs->sharding().IsReplicated()) { - return 
MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, may_combine_partial_sharding); } if (IsConvolutionKernelSmall(instruction)) { @@ -488,11 +591,13 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, return false; } return MaybeImproveInstructionSharding(get_tiled_sharding_based_on_lhs(), - instruction); + instruction, + may_combine_partial_sharding); } // If the kernel is large (e.g backward convolution) then we only support // replicated output. - return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction); + return MaybeImproveInstructionSharding(HloSharding::Replicate(), instruction, + may_combine_partial_sharding); } // Tries to update the sharding of the specified instruction based on its @@ -512,8 +617,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (absl::c_any_of(instruction->operands(), [](const HloInstruction* op) { return op->has_sharding() && op->sharding().IsReplicated(); })) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } return false; } @@ -526,7 +632,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } HloSharding new_sharding = operand->sharding().GetSubSharding( operand->shape(), {instruction->tuple_index()}); - return MaybeImproveInstructionSharding(new_sharding, instruction); + return MaybeImproveInstructionSharding( + new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kTuple: { if (absl::c_none_of(instruction->operands(), @@ -599,40 +706,37 @@ bool InferShardingFromOperands(HloInstruction* instruction, sharding); return HloSharding::Tuple(instruction->shape(), tuple); }; - if (operand->sharding().IsReplicated()) { + if (operand->sharding().IsReplicated() || + (!is_spmd && + absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { + return operand->sharding().tile_assignment().dim(dim) > 1; + }))) { + // We are reducing along one of the sharded dimensions. We only + // support this in SPMD. changed |= MaybeImproveInstructionSharding( - get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, + /*may_combine_partial_sharding=*/is_spmd); continue; } - if (absl::c_any_of(instruction->dimensions(), [operand](int64 dim) { - return operand->sharding().tile_assignment().dim(dim) > 1; - })) { - // We are reducing along one of the sharded dimensions. We don't - // support tiled sharding in this case. + auto after_partial_replication = + operand->sharding().IsReplicated() + ? operand->sharding() + : hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand->sharding(), instruction->dimensions()); + if (after_partial_replication.IsReplicated()) { changed |= MaybeImproveInstructionSharding( - get_maybe_tuple_sharding(HloSharding::Replicate()), instruction); - } else { - // We are reducing along some of the non-sharded dimensions. The - // result sharding should be the same as the operand sharding with the - // reduction dimensions removed as they are removed from the result - // shape. 
- std::vector target_tile_assignment_dimensions; - const auto& dimensions = instruction->dimensions(); - for (int64 i = 0; i < operand->shape().rank(); ++i) { - if (absl::c_find(dimensions, i) == dimensions.end()) { - target_tile_assignment_dimensions.push_back( - operand->sharding().tile_assignment().dim(i)); - } - } - Array new_tile_assignment = - operand->sharding().tile_assignment(); - new_tile_assignment.Reshape(target_tile_assignment_dimensions); - // Use the same sharding for all tuple elements, because they are part - // of the same reduce instruction. - HloSharding new_sharding = - get_maybe_tuple_sharding(HloSharding::Tile(new_tile_assignment)); - changed |= MaybeImproveInstructionSharding(new_sharding, instruction); + get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, + /*may_combine_partial_sharding=*/is_spmd); + continue; } + // Use the same sharding for all tuple elements, because they are part + // of the same reduce instruction. + HloSharding new_sharding = + get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions( + after_partial_replication, instruction->dimensions())); + changed |= MaybeImproveInstructionSharding( + new_sharding, instruction, + /*may_combine_partial_sharding=*/is_spmd); } return changed; } @@ -662,13 +766,23 @@ bool InferShardingFromOperands(HloInstruction* instruction, op->sharding().tile_assignment().dim(source_dim)); } } + if (op->sharding().ReplicateOnLastTileDim()) { + target_tile_assignment_dimensions.push_back( + op->sharding().tile_assignment().dimensions().back()); + } Array new_tile_assignment = op->sharding().tile_assignment(); new_tile_assignment.Reshape(target_tile_assignment_dimensions); - HloSharding new_sharding = HloSharding::Tile(new_tile_assignment); - return MaybeImproveInstructionSharding(new_sharding, instruction); + HloSharding new_sharding = + op->sharding().ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); + return MaybeImproveInstructionSharding( + new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands(instruction, aggressive_prop); + return InferConvolutionShardingFromOperands( + instruction, aggressive_prop, + /*may_combine_partial_sharding=*/is_spmd); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!IsSpatiallyPartitioned(input)) { @@ -676,7 +790,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } HloSharding sharding = hlo_sharding_util::TransposeSharding( input->sharding(), instruction->dimensions()); - return MaybeImproveInstructionSharding(sharding, instruction); + return MaybeImproveInstructionSharding( + sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kReduceWindow: { const HloInstruction* lhs = instruction->operand(0); @@ -694,7 +809,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + return MaybeImproveInstructionSharding( + lhs->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kSelectAndScatter: { // Shard according to first operand, as output keeps the same shape. 
@@ -713,7 +830,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding(lhs->sharding(), instruction); + return MaybeImproveInstructionSharding( + lhs->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kReshape: { if (!IsSpatiallyPartitioned(instruction->operand(0))) { @@ -724,8 +843,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, instruction->operand(0)->shape(), instruction->shape(), instruction->operand(0)->sharding()); if (new_sharding.has_value()) { - return MaybeImproveInstructionSharding(new_sharding.value(), - instruction); + return MaybeImproveInstructionSharding( + new_sharding.value(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } return false; } @@ -736,7 +856,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, return MaybeImproveInstructionSharding( hlo_sharding_util::ReverseSharding( instruction->operand(0)->sharding(), instruction->dimensions()), - instruction); + instruction, /*may_combine_partial_sharding=*/is_spmd); } case HloOpcode::kDot: { auto& dot_dim_numbs = instruction->dot_dimension_numbers(); @@ -765,8 +885,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, } else if (ops_sharding[0]->IsReplicated() && ops_sharding[1]->IsReplicated()) { // Both replicated -> replicate - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } else if (!ops_sharding[0]->IsReplicated() && !ops_sharding[1]->IsReplicated()) { // Both tile sharded. The dot spatial partitioning implementation @@ -785,8 +906,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, } if (ops_sharding[representative_op]->IsReplicated()) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } else { // Tile-shard instruction according to representative op. 
auto sharding = *ops_sharding[representative_op]; @@ -811,7 +933,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, tile_assignment.Reshape(dimensions); sharding = HloSharding::Tile(tile_assignment); } - return MaybeImproveInstructionSharding(sharding, instruction); + return MaybeImproveInstructionSharding( + sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); } } case HloOpcode::kParameter: { @@ -826,7 +949,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (parent->called_computations()[i - 1] == instruction->parent()) { if (parent->operand(i)->has_sharding()) { return MaybeImproveInstructionSharding( - parent->operand(i)->sharding(), instruction); + parent->operand(i)->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } return false; } @@ -853,15 +977,16 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (instruction->shape().IsTuple()) { return MaybeImproveInstructionSharding( HloSharding::SingleTuple(instruction->shape(), operand->sharding()), - instruction); + instruction, /*may_combine_partial_sharding=*/is_spmd); } else { - return MaybeImproveInstructionSharding(operand->sharding(), - instruction); + return MaybeImproveInstructionSharding( + operand->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } } case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: { - auto propagate_slicing = [instruction]() { + auto propagate_slicing = [instruction, is_spmd]() { const HloInstruction* operand = instruction->opcode() == HloOpcode::kDynamicSlice ? instruction->operand(0) @@ -871,8 +996,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, } if (operand->sharding().IsReplicated()) { - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + return MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } const auto& tile_assignment = operand->sharding().tile_assignment(); @@ -883,10 +1009,11 @@ bool InferShardingFromOperands(HloInstruction* instruction, return false; } } - return MaybeImproveInstructionSharding(operand->sharding(), - instruction); + return MaybeImproveInstructionSharding( + operand->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); }; - auto propagate_base = [instruction]() { + auto propagate_base = [instruction, is_spmd]() { if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) { return false; } @@ -894,25 +1021,57 @@ bool InferShardingFromOperands(HloInstruction* instruction, return false; } return MaybeImproveInstructionSharding( - instruction->operand(0)->sharding(), instruction); + instruction->operand(0)->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); }; return propagate_slicing() || propagate_base(); } case HloOpcode::kGather: { - if (!IsSpatiallyPartitioned(instruction->operand(1))) { - return false; + bool changed = false; + if (IsSpatiallyPartitioned(instruction->operand(1))) { + HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( + instruction->operand(1)->sharding(), instruction); + changed |= MaybeImproveInstructionSharding( + new_sharding, instruction, + /*may_combine_partial_sharding=*/is_spmd); } - HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( - instruction->operand(1)->sharding(), instruction); - return MaybeImproveInstructionSharding(new_sharding, instruction); + if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { + auto maybe_from_data = + 
hlo_sharding_util::GatherOutputShardingFromDataOperand( + instruction->operand(0)->sharding(), *instruction); + if (maybe_from_data) { + changed |= MaybeImproveInstructionSharding( + *maybe_from_data, instruction, + /*may_combine_partial_sharding=*/is_spmd); + } + } + return changed; } case HloOpcode::kScatter: { + bool changed = false; + if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { + changed |= MaybeImproveInstructionSharding( + instruction->operand(0)->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); + } if (!IsSpatiallyPartitioned(instruction->operand(1)) && !IsSpatiallyPartitioned(instruction->operand(2))) { return false; } - return MaybeImproveInstructionSharding(HloSharding::Replicate(), - instruction); + if (is_spmd && IsSpatiallyPartitioned(instruction->operand(2))) { + auto maybe_from_update = + hlo_sharding_util::ScatterOutputShardingFromUpdate( + instruction->operand(2)->sharding(), *instruction); + if (maybe_from_update) { + changed |= MaybeImproveInstructionSharding( + *maybe_from_update, instruction, + /*may_combine_partial_sharding=*/is_spmd); + } + } + changed |= MaybeImproveInstructionSharding( + HloSharding::Replicate(), instruction, + /*may_combine_partial_sharding=*/is_spmd); + return changed; } case HloOpcode::kWhile: { if (!instruction->operand(0)->has_sharding()) { @@ -923,14 +1082,28 @@ bool InferShardingFromOperands(HloInstruction* instruction, sharding = MergeForMoreSpecificSharding(sharding, instruction->sharding()); } - return MaybeImproveInstructionSharding(sharding, instruction); + return MaybeImproveInstructionSharding( + sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); } default: { + if (instruction->IsElementwise() && is_spmd) { + bool changed = false; + for (auto operand : instruction->operands()) { + if (IsSpatiallyPartitioned(operand)) { + changed |= MaybeImproveInstructionSharding( + operand->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); + } + } + return changed; + } const HloInstruction* operand = PickRepresentativeOperand(instruction); if (!operand || !IsSpatiallyPartitioned(operand)) { return false; } - return MaybeImproveInstructionSharding(operand->sharding(), instruction); + return MaybeImproveInstructionSharding( + operand->sharding(), instruction, + /*may_combine_partial_sharding=*/is_spmd); } } return false; @@ -948,25 +1121,25 @@ absl::optional GetShardingFromUser( if (user.sharding().IsReplicated()) { return user.sharding(); } - // Only support when none of the partitioned dimensions in the broadcast - // output belong to new dimensions. + std::vector dims_to_replicate; + bool needs_replication = false; for (int64 i = 0; i < user.shape().rank(); ++i) { - if (user.sharding().tile_assignment().dim(i) > 1 && - absl::c_count(user.dimensions(), i) == 0) { - return absl::nullopt; + if (absl::c_count(user.dimensions(), i) == 0) { + dims_to_replicate.push_back(i); + if (user.sharding().tile_assignment().dim(i) > 1) { + needs_replication = true; + } } } - - // The instruction (operand of broadcast) will be tiled the same way - // as the output. - std::vector target_tile_assignment_dimensions; - for (int64 output_dim : user.dimensions()) { - target_tile_assignment_dimensions.push_back( - user.sharding().tile_assignment().dim(output_dim)); + // If not SPMD, only support when none of the partitioned dimensions in + // the broadcast output belong to new dimensions. 
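+      // For example (the SPMD path, mirroring the BroadcastUserPartial test
+      // in sharding_propagation_test.cc): a broadcast to f32[4,24,6,8] with
+      // dimensions={1,3} and output sharding
+      // {devices=[4,2,1,1]0,1,2,3,4,5,6,7} first turns the new dims {0,2}
+      // into a replication group and then removes them, so the f32[24,8]
+      // operand gets
+      // {devices=[2,1,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}.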
+ if (!is_spmd && needs_replication) { + return absl::nullopt; } - Array new_tile_assignment = user.sharding().tile_assignment(); - new_tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(new_tile_assignment); + return hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + user.sharding(), dims_to_replicate), + dims_to_replicate); } case HloOpcode::kConcatenate: { if (user.sharding().IsReplicated()) { @@ -1191,10 +1364,11 @@ absl::optional GetShardingFromUser( return user_sharding; } std::vector target_tile_assignment_dimensions( - instruction.shape().rank()); + instruction.shape().rank() + + (user_sharding.ReplicateOnLastTileDim() ? 1 : 0)); const auto& dimensions = user.dimensions(); int64 next_output_dim = 0; - for (int64 i = 0; i < instruction.shape().rank(); ++i) { + for (int64 i = 0; i < target_tile_assignment_dimensions.size(); ++i) { if (absl::c_find(dimensions, i) == dimensions.end()) { target_tile_assignment_dimensions[i] = user_sharding.tile_assignment().dim(next_output_dim++); @@ -1204,7 +1378,9 @@ absl::optional GetShardingFromUser( } auto tile_assignment = user_sharding.tile_assignment(); tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(tile_assignment); + return user_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } case HloOpcode::kSort: { if (user.sharding().IsTuple()) { @@ -1218,6 +1394,43 @@ absl::optional GetShardingFromUser( return hlo_sharding_util::ReverseSharding(user.sharding(), user.dimensions()); } + case HloOpcode::kGather: { + if (&instruction == user.operand(1)) { + return hlo_sharding_util::GatherIndexSharding(user.sharding(), &user); + } + if (is_spmd) { + return hlo_sharding_util::GatherDataOperandShardingFromOutput( + user.sharding(), user); + } + return absl::nullopt; + } + case HloOpcode::kScatter: { + if (&instruction == user.operand(0)) { + return user.sharding(); + } + if (&instruction == user.operand(1)) { + auto update = user.operand(2); + if (!IsSpatiallyPartitioned(update)) { + return absl::nullopt; + } + return hlo_sharding_util::ScatterIndexSharding(update->sharding(), + &user); + } + CHECK_EQ(&instruction, user.operand(2)); + auto indices = user.operand(1); + if (IsSpatiallyPartitioned(indices)) { + auto from_indices = + hlo_sharding_util::ScatterDataSharding(indices->sharding(), &user); + if (!from_indices.IsTileMaximal()) { + return from_indices; + } + } + if (is_spmd) { + return hlo_sharding_util::ScatterUpdateShardingFromOutput( + user.sharding(), user); + } + return absl::nullopt; + } default: { // If the user output shape is compatible with the current instruction // shape excluding element type and the current instruction is supported @@ -1246,8 +1459,9 @@ bool InferShardingFromUsers(HloInstruction* instruction, absl::optional user_sharding = GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd); if (user_sharding) { - improved_sharding |= - MaybeImproveInstructionSharding(*user_sharding, instruction); + improved_sharding |= MaybeImproveInstructionSharding( + *user_sharding, instruction, + /*may_combine_partial_sharding=*/is_spmd); } } return improved_sharding; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index d62328aa9ad..a182af001c2 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ 
b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -118,6 +118,25 @@ ENTRY %broadcast { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, BroadcastForwardPartial) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[3,2048]parameter(0), + sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %broadcast = f32[3,2048,3] broadcast(%param0), dimensions={0,1} + ROOT %copy = f32[3,2048,3] copy(%broadcast) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "broadcast"), + op::Sharding("{devices=[1,2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, BroadcastUser) { const char* const hlo_string = R"( HloModule module @@ -136,6 +155,25 @@ ENTRY %broadcast { op::Sharding("{devices=[2,4]0,1,2,3,4,5,6,7}")); } +TEST_F(ShardingPropagationTest, BroadcastUserPartial) { + const char* const hlo_string = R"( +HloModule module +ENTRY %broadcast { + %param0 = f32[24,8]{0,1} parameter(0) + %copy = f32[24,8]{0,1} copy(%param0) + ROOT %broadcast = f32[4,24,6,8] broadcast(%copy), dimensions={1,3}, + sharding={devices=[4,2,1,1]0,1,2,3,4,5,6,7} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, MaximalReduceForwardPass) { const char* const hlo_string = R"( HloModule module @@ -184,6 +222,78 @@ ENTRY %reduce { op::Sharding("{devices=[2,2]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, ReducePartiallyOnTiledDims) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0), sharding={devices=[2,2]0,1,2,3} + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%param0, %init), dimensions={0}, to_apply=%add + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,2]0,2,1,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ReducePartiallyOnTiledDims2) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0), sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%param0, %init), dimensions={0}, to_apply=%add + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "reduce"), + op::Sharding("{devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + 
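+// A rough walk-through of the forward reduce rule for the
+// ReducePartiallyOnTiledDims test above (assuming the hlo_sharding_util
+// helpers behave as the pass uses them):
+//   operand f32[8,8]: tile_assignment [2,2] = 0,1,2,3
+//   PartiallyReplicateTiledShardingOnDims(sharding, {0})
+//     -> tile_assignment [1,2,2] = 0,2,1,3: the reduced dim 0 is folded into
+//        a replication group, so column shard 0 is replicated on {0,2} and
+//        column shard 1 on {1,3}
+//   RemoveShapeDimensions(partial, {0})
+//     -> tile_assignment [2,2] = 0,2,1,3 with last_tile_dim_replicate, which
+//        is exactly the sharding the test expects on %reduce.
+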
+TEST_F(ShardingPropagationTest, ReducePartiallyBackward) { + const char* const hlo_string = R"( +HloModule module +%add { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(%lhs, %rhs) +} +ENTRY %reduce { + %param0 = f32[8,8] parameter(0) + %input = f32[8,8] copy(%param0) + %init = f32[] parameter(1) + %reduce = f32[8] reduce(%input, %init), dimensions={0}, to_apply=%add, + sharding={devices=[2,2]0,1,2,3 last_tile_dim_replicate} + ROOT %copy = f32[8] copy(%reduce) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, ShardedTupleReduceForwardAndBackwardPass) { const char* const hlo_string = R"( HloModule module @@ -1149,21 +1259,21 @@ ENTRY entry { ShardingPropagation().Run(module.get())); EXPECT_TRUE(changed); EXPECT_THAT(FindInstruction(module.get(), "tp"), - op::Sharding("{{devices=[1,2]0,1}}")); + op::Sharding("{{devices=[3,1]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "tgte"), - op::Sharding("{devices=[1,2]0,1}")); + op::Sharding("{devices=[3,1]0,1,2}")); EXPECT_THAT(FindInstruction(module.get(), "ttr"), - op::Sharding("{devices=[2,1]0,1}")); + op::Sharding("{devices=[1,3]0,1,2}")); EXPECT_THAT(FindInstruction(module.get(), "tr"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "fp"), op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "fgte"), op::Sharding("{devices=[1,3]0,1,2}")); EXPECT_THAT(FindInstruction(module.get(), "fr"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); EXPECT_THAT(FindInstruction(module.get(), "conditional"), - op::Sharding("{{devices=[2,1]0,1}}")); + op::Sharding("{{devices=[1,3]0,1,2}}")); } TEST_F(ShardingPropagationTest, TupleFromUser) { @@ -1494,5 +1604,328 @@ ENTRY entry { op::Sharding("{devices=[2,1,1,1]0,1}")); } +TEST_F(ShardingPropagationTest, GatherFromIndex) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={devices=[2]0,1} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, GatherFromDataOperand) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9} + ROOT %copy = f32[3,9] copy(%gather) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + 
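+  // The data operand is sharded on its non-collapsed dimension (dim 1), which
+  // maps to the gather output's offset dimension, hence the expectation below.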
EXPECT_THAT(FindInstruction(module.get(), "gather"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, GatherToIndex) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1) + %indices = s32[3] copy(%p1) + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2]0,1}")); +} + +TEST_F(ShardingPropagationTest, GatherToDataOperand) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + ROOT %gather = f32[3,9] gather(%input, %indices), offset_dims={1}, + collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, + slice_sizes={1,9}, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, DataOperandToScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={devices=[1,2]0,1} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, UpdateOperandToScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={devices=[1,2]0,1} + %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + ROOT %copy = f32[2,9] copy(%scatter) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "scatter"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, 
ScatterToDataOperand) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %p0 = f32[2,9] parameter(0) + %input = f32[2,9] copy(%p0) + %indices = s32[3] parameter(1), sharding={replicated} + %updates = f32[3,9] parameter(2), sharding={replicated} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "input"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, ScatterToUpdateOperand) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0) + %indices = s32[3] parameter(1), sharding={replicated} + %p2 = f32[3,9] parameter(2) + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={devices=[1,2]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[1,2]0,1}")); +} + +TEST_F(ShardingPropagationTest, ScatterUpdateToIndex) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %p1 = s32[3] parameter(1), sharding={replicated} + %indices = s32[3] copy(%p1) + %updates = f32[3,9] parameter(2), sharding={devices=[2,1]0,1} + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[2]0,1}")); +} + +TEST_F(ShardingPropagationTest, ScatterIndexToUpdate) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9] parameter(0), sharding={replicated} + %indices = s32[3] parameter(1), sharding={devices=[2]0,1} + %p2 = f32[3,9] parameter(2), sharding={replicated} + %updates = f32[3,9] copy(%p2) + ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1, sharding={replicated} +})"; + 
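+  // Backward case: the scatter's output sharding applies directly to its data
+  // operand (%input here), since the scatter output and the data operand have
+  // the same shape.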
TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "updates"), + op::Sharding("{devices=[2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingOnElementwise) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0), sharding={devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} + %p1 = f32[2,9] parameter(1), sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + %lhs = f32[2,9] copy(%p0) + %rhs = f32[2,9] copy(%p1) + %add = f32[2,9] add(%lhs, %rhs) + ROOT %copy = f32[2,9] copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "lhs"), + op::Sharding("{devices=[2,2]0,2,1,3}")); + EXPECT_THAT(FindInstruction(module.get(), "rhs"), + op::Sharding("{devices=[2,2]0,2,1,3}")); + EXPECT_THAT(FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2]0,2,1,3}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingOnElementwise2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2,9] parameter(0), sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %p1 = f32[2,9] parameter(1), sharding={devices=[2,1,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} + %lhs = f32[2,9] copy(%p0) + %rhs = f32[2,9] copy(%p1) + %add = f32[2,9] add(%lhs, %rhs) + ROOT %copy = f32[2,9] copy(%add) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "lhs"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); + EXPECT_THAT( + FindInstruction(module.get(), "rhs"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); + EXPECT_THAT( + FindInstruction(module.get(), "add"), + op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index 4433078472d..ce19934bb88 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -17,6 +17,8 @@ package_group( cc_library( name = "spmd_partitioner", srcs = [ + "convolution_handler.cc", + "dot_handler.cc", "spmd_partitioner.cc", "spmd_partitioner_util.cc", ], @@ -48,6 +50,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/core/platform:numbers", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/service/spmd/convolution_handler.cc b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc new file mode 100644 index 00000000000..01d7ea2ff14 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/convolution_handler.cc @@ -0,0 +1,1013 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_util.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"
+#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/numbers.h"
+
+namespace xla {
+namespace spmd {
+namespace {
+
+// Partition convolution.
+StatusOr<HloInstruction*> PartitionConvolution(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding, const Window& conv_window,
+    HloInstruction* original_hlo, int64 num_partitions,
+    const SpmdPartitionerOptions& options, HloInstruction* partition_id,
+    HloModule* module, SpmdBuilder* b);
+
+// Partition convolution when only parallel dims are tiled.
+StatusOr<HloInstruction*> PartitionConvolutionWithParallelDimension(
+    PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape,
+    const HloSharding& output_sharding, const Window& conv_window,
+    HloInstruction* original_hlo, int64 num_partitions, SpmdBuilder* b) {
+  TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution);
+
+  const auto& dnums = original_hlo->convolution_dimension_numbers();
+  std::vector<int64> rhs_to_lhs_indices(output_base_shape.rank());
+  rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] =
+      dnums.input_batch_dimension();
+  rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] =
+      dnums.input_feature_dimension();
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] =
+        dnums.input_spatial_dimensions(i);
+  }
+  std::vector<int64> lhs_to_rhs_indices(output_base_shape.rank());
+  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
+    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
+  }
+  auto aligned_rhs_sharding =
+      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
+  auto aligned_lhs_sharding =
+      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);
+
+  // Handling cases where all the partitioned dimensions are parallel
+  // dimensions.
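+  // Roughly, a spatial dimension counts as "parallel" when its window makes
+  // the convolution behave like a batch dimension there (the same
+  // size/stride/base_dilation pattern the per-shard window is rewritten to
+  // further below). The counters that follow check whether the shard counts
+  // over such dimensions multiply up to num_partitions, i.e. whether LHS (or
+  // RHS) is partitioned only along parallel dimensions; if neither is, this
+  // strategy does not apply and nullptr is returned.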
+ int64 lhs_parallel_dim_partitions = 1; + int64 rhs_parallel_dim_partitions = 1; + std::vector parallel_spatial_dims; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dim = dnums.input_spatial_dimensions(i); + int64 lhs_size = lhs.base_shape().dimensions(lhs_dim); + const auto& wd = conv_window.dimensions(i); + int64 rhs_dim = dnums.kernel_spatial_dimensions(i); + if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) { + parallel_spatial_dims.emplace_back(i); + lhs_parallel_dim_partitions *= ShardCountAtDim(lhs.sharding(), lhs_dim); + rhs_parallel_dim_partitions *= ShardCountAtDim(rhs.sharding(), rhs_dim); + } + } + bool lhs_partition_dims_are_parallel = + (lhs_parallel_dim_partitions == num_partitions); + bool rhs_partition_dims_are_parallel = + (rhs_parallel_dim_partitions == num_partitions); + + // If there is a parallel dim and all the partitioned dimensions are parallel + // dimensions in either LHS or RHS, simply create partitioned convolutions. + if (parallel_spatial_dims.empty() || ((!lhs_partition_dims_are_parallel) && + (!rhs_partition_dims_are_parallel))) { + return nullptr; + } + // Reshard LHS or RHS to partition at parallel dimensions as the other + // operand. + if (lhs_partition_dims_are_parallel) { + rhs = rhs.Reshard(aligned_rhs_sharding); + } else { + lhs = lhs.Reshard(aligned_lhs_sharding); + } + + // Get LHS and RHS sharded shape. + auto lhs_shard_shape = MakePartitionedShape(lhs.base_shape(), lhs.sharding()); + auto rhs_shard_shape = MakePartitionedShape(rhs.base_shape(), rhs.sharding()); + + // Update convolution window. + auto new_window = conv_window; + for (const auto& spatial_dim : parallel_spatial_dims) { + auto wd = new_window.mutable_dimensions(spatial_dim); + wd->set_size(lhs_shard_shape.dimensions( + dnums.input_spatial_dimensions(spatial_dim))); + wd->set_stride(std::max(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + lhs_shard_shape, rhs_shard_shape, original_hlo->feature_group_count(), + original_hlo->batch_group_count(), new_window, dnums)); + auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, lhs.hlo(), rhs.hlo(), + original_hlo->feature_group_count(), original_hlo->batch_group_count(), + new_window, dnums, original_hlo->precision_config())); + sharded_conv->set_sharding(original_hlo->sharding()); + return PartitionedHlo(sharded_conv, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition convolution when both LHS and RHS are partitioned at spatial +// dimensions. Halo exchange will happen on RHS only. 
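+// Each partition convolves its (aligned) LHS shard with an RHS shard that has
+// been extended with halos exchanged from neighboring partitions, producing a
+// partial result of the full output shape; the partial results are then summed
+// with a cross-partition all-reduce and resharded to the requested output
+// sharding.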
+StatusOr +PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, HloInstruction* partition_id, + HloModule* module, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = original_hlo->convolution_dimension_numbers(); + std::vector rhs_to_lhs_indices(output_base_shape.rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(output_base_shape.rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + auto aligned_rhs_sharding = + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); + auto aligned_lhs_sharding = + hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + // We currently don't support partitioning input batch or output feature + // dimensions. + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); + if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return nullptr; + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return nullptr; + } + lhs = lhs.PadWithValue(zero); + rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); + } + + // Reshard RHS so that each shard computes the partial sum of the full + // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() + // that reshards LHS. + // + // The size of halo on each dimension can be calculated from the + // projection onto the RHS that shard i needs to read. RHS and LHS below + // refers to the shard size of RHS and LHS, WC is the number of windows, + // and D is the window dilation. 
+  //
+  // * offset(i): LHS * i + low_padding - (WC - 1) * stride
+  // * limit(i): LHS * (i + 1) + low_padding
+  //
+  // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D)
+  // * left-halo: i * RHS - offset(i)
+  //            = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding
+  // * right-halo: limit(i) - (i + 1) * RHS
+  //            = (i + 1) * (LHS - RHS * D) + low_padding
+  const auto& collective_ops_creator = lhs.state().collective_ops_creator;
+  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
+  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
+  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
+
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
+    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
+    int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension);
+    auto wd = conv_window.dimensions(i);
+    if (wd.base_dilation() != 1 || wd.window_reversal()) {
+      return nullptr;
+    }
+
+    int64 lhs_shard_size =
+        CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count);
+    int64 rhs_shard_size =
+        CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count);
+    shard_counts[i] = shard_count;
+    lhs_shard_sizes[i] = lhs_shard_size;
+    rhs_shard_sizes[i] = rhs_shard_size;
+  }
+
+  std::vector<OffsetCalculation> left_halo_size_functions(
+      output_base_shape.rank());
+  std::vector<OffsetCalculation> right_halo_size_functions(
+      output_base_shape.rank());
+  Window new_window = conv_window;
+
+  // Data structures needed for Pad and DynamicSlice on LHS if needed.
+  bool need_dynamic_slice_lhs = false;
+  auto partition_ordinals =
+      MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b);
+  std::vector<int64> zero_padding(output_base_shape.rank());
+  PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding);
+  auto zero_s32 =
+      b->AddInstruction(HloInstruction::CreateConstant(LiteralUtil::Zero(S32)));
+  std::vector<HloInstruction*> dynamic_slice_start_indices(
+      output_base_shape.rank(), zero_s32);
+  Shape dynamic_slice_shape = lhs.hlo()->shape();
+  Shape pad_shape = lhs.hlo()->shape();
+
+  for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) {
+    int64 lhs_dimension = dnums.input_spatial_dimensions(i);
+    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
+    int64 lhs_shard_size = lhs_shard_sizes[i];
+    int64 rhs_shard_size = rhs_shard_sizes[i];
+
+    if (shard_counts[i] == 1) {
+      continue;
+    }
+
+    // Calculate the left and right halo sizes as described in the comments
+    // above. It calculates the halo sizes with dilation, so we apply
+    // CeilOfRatio({left,right}_halo_size, window_dilation).
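+    // The OffsetCalculation objects below encode these halo sizes as
+    // (multiplier * shard_ordinal + offset) / divisor; using window_dilation
+    // as the divisor and folding (window_dilation - 1) into the additive term
+    // is what realizes that CeilOfRatio.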
+ auto wd = conv_window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + left_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + (window_count - 1) * wd.stride() - padding_low + + wd.window_dilation() - 1, + wd.window_dilation())); + right_halo_size_functions[rhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), + lhs_shard_size - rhs_shard_size * wd.window_dilation() + + padding_low + wd.window_dilation() - 1, + wd.window_dilation())); + + // New RHS window size includes the maximum of both left and right + // halos. + int64 halo_size = + left_halo_size_functions[rhs_dimension].MaxInRange(1, shard_counts[i]) + + right_halo_size_functions[rhs_dimension].MaxInRange( + 0, shard_counts[i] - 1); + int64 new_window_size = + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; + + // The amount of new low padding could be dynamic (e.g., window_dilation + // != 1), which requires pad (to the maximum) and dynamic slice on LHS. + // + // If we consider the first window, the offset of the dilated RHS that + // aligns with the first valid LHS element for shard i is 'padding_low + + // LHS * i'. When the left halo is added to RHS, the offset of the first + // RHS element is (RHS * i - left_halo) * window_dilation. The + // difference between the two values is the amount of padding_low we + // need on LHS. + auto new_padding_low_function = + OffsetCalculation(HloOpcode::kMultiply, + left_halo_size_functions[rhs_dimension], + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + 0, wd.window_dilation(), 1))) - + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + -padding_low, 1)); + + int64 new_padding_low_max = + new_padding_low_function.MaxInRange(0, shard_counts[i]); + int64 new_padding_low = new_padding_low_max; + int64 new_padding_high = window_count * wd.stride() + + (new_window_size - 1) * wd.window_dilation() - + new_padding_low - lhs_shard_size; + + // We do pad/dynamic-slice only when the padding is dynamic. + if (!new_padding_low_function.IsConstant()) { + need_dynamic_slice_lhs = true; + new_padding_low = 0; + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_low(new_padding_low_max); + pad_config.mutable_dimensions(lhs_dimension) + ->set_edge_padding_high(new_padding_low_max); + pad_shape.set_dimensions(lhs_dimension, + lhs_shard_size + 2 * new_padding_low_max); + dynamic_slice_start_indices[lhs_dimension] = + (OffsetCalculation( + MultiplyAddDivideOffsetCalculation(0, new_padding_low_max, 1)) - + new_padding_low_function) + .Calculate(partition_ordinals[lhs_dimension], b); + dynamic_slice_shape.set_dimensions(lhs_dimension, + lhs_shard_size + new_padding_low_max); + } + + // Since the convolution RHS operand size increased with halos, adjust + // the window config accordingly. 
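+    // (Its size becomes the local RHS shard size plus the exchanged halo, and
+    // the padding computed above replaces the original padding.)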
+ new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); + new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); + new_window.mutable_dimensions(i)->set_size( + rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); + } + + HloInstruction* conv_lhs = lhs.hlo(); + if (need_dynamic_slice_lhs) { + auto pad = b->AddInstruction( + HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); + conv_lhs = b->AddInstruction(HloInstruction::CreateDynamicSlice( + dynamic_slice_shape, pad, dynamic_slice_start_indices, + dynamic_slice_shape.dimensions())); + } + + // Exchange halo and concatenate. + HloInstruction* rhs_with_halo = rhs.hlo(); + for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { + int64 dim = dnums.kernel_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = + left_halo_size_functions[dim].Calculate(0); + int64 shard_size_with_halo = new_window.dimensions(i).size(); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). + // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - + left_halo_size_functions[dim]; + int64 padded_full_shape_size = + offset_on_padded_shape.Calculate(shard_counts[i] - 1) + + new_window.dimensions(i).size(); + auto concat = ExchangeHaloAndGetValidData( + rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero, + partition_ordinals[dim], collective_ops_creator, + lhs.state().next_channel_id, b, + /*mask_invalid_region=*/false); + if (!concat) { + return nullptr; + } + rhs_with_halo = *concat; + } + + auto conv = b->AddInstruction(HloInstruction::CreateConvolve( + output_base_shape, conv_lhs, rhs_with_halo, + original_hlo->feature_group_count(), original_hlo->batch_group_count(), + new_window, dnums, original_hlo->precision_config())); + auto ar = collective_ops_creator.create_cross_partition_all_reduce( + b, conv, MakeBinaryAdd(original_hlo->shape().element_type(), module), {}, + (*lhs.state().next_channel_id)++); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition convolution when both LHS and RHS are partitioned at spatial +// dimensions. Halo exchange will happen on LHS only. +StatusOr +PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, HloInstruction* partition_id, + HloModule* module, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && + !rhs.sharding().IsTileMaximal()); + + const auto& dnums = original_hlo->convolution_dimension_numbers(); + + // Check if the operand shardings are aligned. 
Also we currently don't + // support partitioning non-spatial dimensions. + std::vector rhs_to_lhs_indices(output_base_shape.rank()); + rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = + dnums.input_batch_dimension(); + rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = + dnums.input_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = + dnums.input_spatial_dimensions(i); + } + std::vector lhs_to_rhs_indices(output_base_shape.rank()); + for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { + lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; + } + + Window window = conv_window; + std::vector reversed_rhs_dims; + for (int64 i = 0; i < window.dimensions_size(); ++i) { + if (window.dimensions(i).window_reversal()) { + reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i)); + } + } + if (!reversed_rhs_dims.empty()) { + // Make the reversed dims left-padded to prepare for window reversal. + auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims); + if (left_padded_rhs == nullptr) { + return nullptr; + } + left_padded_rhs->set_sharding(rhs.sharding()); + rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state()); + } + // Consider window reversal when resharding RHS or LHS. Note: this will not + // reverse the data in the shard. We use window reversal to do that. + auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding( + hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices), + reversed_rhs_dims); + auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding( + hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims), + lhs_to_rhs_indices); + + auto unsupported_sharding = [&](const HloSharding& lhs_sharding, + const HloSharding& rhs_sharding) { + return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != + 1 || + rhs_sharding.tile_assignment().dim( + dnums.kernel_output_feature_dimension()) != 1; + }; + + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); + if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { + if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { + return nullptr; + } + lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); + rhs = rhs.PadWithValue(zero, reversed_rhs_dims); + } else { + if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { + return nullptr; + } + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims); + } + + // Reshard LHS by exchanging halo such that each shard computes the partial + // sum of the full shape result, and add AllReduce. + // + // The size of halo on each dimension can be calculated from the projection + // onto the LHS that each RHS shard i needs to read. RHS and LHS below refers + // to the shard size of RHS and LHS, WC is the number of windows, and D is the + // window dilation. 
+ // + // * offset(i): RHS * D * i - low_padding + // * limit(i): {RHS * (i + 1) * D - (D - 1)} + (WC - 1) * stride - low_padding + // + // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) + // * left-halo: i * LHS - offset(i) + // = (LHS - RHS * D) * i + low_padding + // * right-halo: limit(i) - (i + 1) * LHS + // = (RHS * D - LHS) * (i + 1) + (1 - D) + (WC - 1) * stride - low_padding + // = (RHS * D - LHS) * i + (RHS * D - LHS) + (1-D) + // + (WC - 1) * stride - low_padding + std::vector shard_counts(dnums.input_spatial_dimensions_size()); + std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); + std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); + int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); + auto wd = window.dimensions(i); + if (wd.base_dilation() != 1) { + // TODO(wangtao): support parallel dim if it is replicate here. + return nullptr; + } + + int64 lhs_shard_size = + CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); + int64 rhs_shard_size = + CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); + shard_counts[i] = shard_count; + lhs_shard_sizes[i] = lhs_shard_size; + rhs_shard_sizes[i] = rhs_shard_size; + } + + std::vector left_halo_size_functions( + output_base_shape.rank()); + std::vector right_halo_size_functions( + output_base_shape.rank()); + Window new_window = window; + + auto partition_ordinals = + MakeTiledPartitionOrdinals(lhs.sharding(), partition_id, b); + HloInstruction* lhs_with_halo = lhs.hlo(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dimension = dnums.input_spatial_dimensions(i); + int64 lhs_shard_size = lhs_shard_sizes[i]; + int64 rhs_shard_size = rhs_shard_sizes[i]; + + if (shard_counts[i] == 1) { + continue; + } + + // Calculate the left and right halo sizes as described in the comments + // above. + auto wd = window.dimensions(i); + int64 padding_low = wd.padding_low(); + int64 padding_high = wd.padding_high(); + int64 base = lhs.base_shape().dimensions(lhs_dimension); + int64 window_count = 1 + (padding_low + padding_high + base - + (1 + (wd.size() - 1) * wd.window_dilation())) / + wd.stride(); + int64 rhs_shard_size_dilated = + (rhs_shard_size - 1) * wd.window_dilation() + 1; + + left_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, + 1)); + right_halo_size_functions[lhs_dimension] = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + rhs_shard_size * wd.window_dilation() - lhs_shard_size, + rhs_shard_size * wd.window_dilation() - lhs_shard_size + 1 - + wd.window_dilation() + wd.stride() * (window_count - 1) - + padding_low, + 1)); + + // Exchange halo and concatenate. + int64 dim = dnums.input_spatial_dimensions(i); + int64 explicit_left_padding_on_full_shape = padding_low; + int64 shard_size_with_halo = + wd.stride() * (window_count - 1) + rhs_shard_size_dilated; + + new_window.mutable_dimensions(i)->set_padding_low(0); + new_window.mutable_dimensions(i)->set_padding_high(0); + new_window.mutable_dimensions(i)->set_size(rhs_shard_size); + + // offset_on_padded_shape and padded_full_shape_size are needed only if + // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). 
+ // Since the default value for both the collective-permute is zero and + // also we call PadWithValue() on both operands at the beginning, we + // don't need to mask here. + // + // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls + // if it's always safe. + auto offset_on_padded_shape = + OffsetCalculation(MultiplyAddDivideOffsetCalculation()); + int64 padded_full_shape_size = 0; + auto concat = ExchangeHaloAndGetValidData( + lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], + right_halo_size_functions[dim], explicit_left_padding_on_full_shape, + padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), + offset_on_padded_shape.Calculate(partition_ordinals[dim], b), zero, + partition_ordinals[dim], lhs.state().collective_ops_creator, + lhs.state().next_channel_id, b, + /*mask_invalid_region=*/false); + if (!concat) { + return nullptr; + } + lhs_with_halo = *concat; + } + + auto conv = b->AddInstruction(HloInstruction::CreateConvolve( + output_base_shape, lhs_with_halo, rhs.hlo(), + original_hlo->feature_group_count(), original_hlo->batch_group_count(), + new_window, original_hlo->convolution_dimension_numbers(), + original_hlo->precision_config())); + auto ar = + lhs.state().collective_ops_creator.create_cross_partition_all_reduce( + b, conv, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + (*lhs.state().next_channel_id)++); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition convolution when output is sharded. Will shard LHS with replicated +// RHS. +StatusOr PartitionConvolutionTiledOutput( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + const auto& dnums = original_hlo->convolution_dimension_numbers(); + TF_RET_CHECK(!output_sharding.IsTileMaximal()); + // We don't currently support sharding on output feature dimension. + if (output_sharding.tile_assignment().dim(dnums.output_feature_dimension()) > + 1) { + return nullptr; + } + + // Check if the operand and the output sharding are aligned. + std::vector input_to_output_indices(output_base_shape.rank()); + input_to_output_indices[dnums.input_batch_dimension()] = + dnums.output_batch_dimension(); + input_to_output_indices[dnums.input_feature_dimension()] = + dnums.output_feature_dimension(); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + input_to_output_indices[dnums.input_spatial_dimensions(i)] = + dnums.output_spatial_dimensions(i); + } + auto target_operand_sharding = hlo_sharding_util::TransposeSharding( + output_sharding, input_to_output_indices); + lhs = lhs.Reshard(target_operand_sharding); + + // Replicate the RHS. + rhs = rhs.Reshard(HloSharding::Replicate()); + + // Convolution window config does not include batch and feature dimensions, + // whereas ReshardAsWindowedInput() expects the same number of window + // dimensions as the rank of the operand. So add two more trivial + // dimensions. 
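+  // For example, for an NHWC-laid-out 2-D convolution (an illustrative layout,
+  // not a requirement), the 2-entry conv window becomes a rank-4 window: the
+  // spatial entries land at the input spatial dimension positions 1 and 2,
+  // while the batch and feature positions keep trivial size-1 dimensions.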
+ std::vector ones(output_base_shape.rank(), 1); + auto operand_window = window_util::MakeWindow(ones); + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = + conv_window.dimensions(i); + } + + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); + auto resharded_operand_and_window = + lhs.ReshardAsWindowedInput(operand_window, target_operand_sharding, zero); + if (!resharded_operand_and_window.has_value()) { + return nullptr; + } + Window new_window; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + *new_window.add_dimensions() = + resharded_operand_and_window->shard_window.dimensions( + dnums.input_spatial_dimensions(i)); + } + TF_ASSIGN_OR_RETURN( + Shape sharded_conv_shape, + ShapeInference::InferConvolveShape( + resharded_operand_and_window->sharded_input->shape(), + rhs.hlo()->shape(), original_hlo->feature_group_count(), + original_hlo->batch_group_count(), new_window, dnums)); + auto shard_shape = MakePartitionedShape(output_base_shape, output_sharding); + *sharded_conv_shape.mutable_layout() = shard_shape.layout(); + auto sharded_conv = b->AddInstruction(HloInstruction::CreateConvolve( + sharded_conv_shape, resharded_operand_and_window->sharded_input, + rhs.hlo(), original_hlo->feature_group_count(), + original_hlo->batch_group_count(), new_window, dnums, + original_hlo->precision_config())); + if (!resharded_operand_and_window->dynamic_slice_index_on_output + .has_value()) { + CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); + return sharded_conv; + } + return b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, sharded_conv, + *resharded_operand_and_window->dynamic_slice_index_on_output, + shard_shape.dimensions())); +} + +StatusOr PartitionConvolutionGroupOnParallelDim( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, const ConvolutionDimsMapping& dims_mapping, + int64 num_partitions, const SpmdPartitionerOptions& options, + HloInstruction* partition_id, HloModule* module, SpmdBuilder* b) { + std::vector lhs_dims; + std::vector rhs_dims; + std::vector output_dims; + auto lhs_sharding_dims_adjusted_to_output = + lhs.sharding().IsReplicated() + ? std::vector(lhs.base_shape().rank(), 1) + : lhs.sharding().tile_assignment().dimensions(); + auto rhs_sharding_dims_adjusted_to_output = + rhs.sharding().IsReplicated() + ? 
std::vector(rhs.base_shape().rank(), 1) + : rhs.sharding().tile_assignment().dimensions(); + auto output_sharding_dims_adjusted_to_lhs = + output_sharding.tile_assignment().dimensions(); + bool lhs_rhs_dims_matching = true; + for (const auto& dim : dims_mapping.parallel_spatial_dims) { + lhs_dims.push_back(dim.lhs); + rhs_dims.push_back(dim.rhs); + output_dims.push_back(dim.output); + if (lhs_sharding_dims_adjusted_to_output[dim.lhs] != + rhs_sharding_dims_adjusted_to_output[dim.rhs]) { + lhs_rhs_dims_matching = false; + } + lhs_sharding_dims_adjusted_to_output[dim.lhs] = + output_sharding.tile_assignment().dim(dim.output); + rhs_sharding_dims_adjusted_to_output[dim.rhs] = + output_sharding.tile_assignment().dim(dim.output); + output_sharding_dims_adjusted_to_lhs[dim.output] = + lhs.sharding().tile_assignment().dim(dim.lhs); + } + auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); + auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); + auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); + if (lhs_rhs_dims_matching) { + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); + rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); + } else { + lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); + lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); + } + auto reshaped_output_tiling = output_sharding.tile_assignment(); + reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); + output_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), + output_dims), + lhs_grouped); + } else { + auto reshaped_lhs_tiling = lhs.sharding().tile_assignment(); + reshaped_lhs_tiling.Reshape(lhs_sharding_dims_adjusted_to_output); + lhs_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_lhs_tiling), lhs_dims), + output_grouped); + lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); + auto reshaped_rhs_tiling = rhs.sharding().tile_assignment(); + reshaped_rhs_tiling.Reshape(rhs_sharding_dims_adjusted_to_output); + rhs_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_rhs_tiling), rhs_dims), + output_grouped); + rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); + } + + // Update LHS and RHS sharding and shape. 
+ lhs.hlo()->set_sharding(lhs_grouped.sharding); + rhs.hlo()->set_sharding(rhs_grouped.sharding); + CHECK(lhs.hlo() != rhs.hlo() || lhs_grouped.sharding == rhs_grouped.sharding); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + lhs.state(), lhs_grouped.device_groups, b); + auto grouped_lhs_base_shape = + GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()); + auto grouped_lhs_shard_shape = + MakePartitionedShape(grouped_lhs_base_shape, lhs.sharding()); + // Update convolution window with the new shape + auto new_window = conv_window; + for (const auto& dim : dims_mapping.parallel_spatial_dims) { + auto wd = new_window.mutable_dimensions(dim.spatial); + wd->set_size(grouped_lhs_shard_shape.dimensions(dim.lhs)); + wd->set_stride(std::max(1, wd->size() - 1)); + wd->set_base_dilation(wd->size()); + } + + auto new_partition_id = + lhs.state().collective_ops_creator.create_partition_id(b); + TF_ASSIGN_OR_RETURN( + auto conv, + PartitionConvolution( + PartitionedHlo(lhs.hlo(), grouped_lhs_base_shape, + per_group_partitioner_state), + PartitionedHlo(rhs.hlo(), + GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), + per_group_partitioner_state), + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, new_window, original_hlo, + num_partitions / output_grouped.device_groups.size(), options, + new_partition_id, module, b)); + // Reset the LHS sharding to the ungrouped one. + lhs.hlo()->set_sharding(UngroupSharding(lhs_grouped)); + rhs.hlo()->set_sharding(UngroupSharding(rhs_grouped)); + conv->set_sharding(UngroupSharding(output_grouped)); + return PartitionedHlo(conv, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition convolution with only one kind of dims partitioned. +StatusOr PartitionConvolutionBaseCase( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, int64 num_partitions, + const SpmdPartitionerOptions& options, HloInstruction* partition_id, + HloModule* module, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + + // Case 1: Either RHS or LHS is only partitioned at parallel dimensions. + TF_ASSIGN_OR_RETURN(auto parallel_partitioned_conv, + PartitionConvolutionWithParallelDimension( + lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, num_partitions, b)); + if (parallel_partitioned_conv) { + return parallel_partitioned_conv; + } + + // Case 2: both RHS and LHS are tiled. + // Handling cases where both operands' shardings are aligned. We check that + // the LHS batch dimension is not partitioned because it is mapped to the + // output feature dimension in aligned_rhs_sharding, which are not the same + // dimension. + if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { + if (options.conv_halo_exchange_always_on_lhs) { + TF_ASSIGN_OR_RETURN( + auto partitioned_conv, + PartitionConvolutionWithSpatialDimensionHaloExchangeOnLHS( + lhs, rhs, output_base_shape, output_sharding, conv_window, + original_hlo, partition_id, module, b)); + if (partitioned_conv) { + return partitioned_conv; + } + } else { + TF_ASSIGN_OR_RETURN( + auto partitioned_conv, + PartitionConvolutionWithSpatialDimensionHaloExchangeOnRHS( + lhs, rhs, output_base_shape, output_sharding, conv_window, + original_hlo, partition_id, module, b)); + + if (partitioned_conv) { + return partitioned_conv; + } + } + } + + // Case 3: output is tiled. 
+ if (!output_sharding.IsTileMaximal()) { + TF_ASSIGN_OR_RETURN(auto partitioned_conv, + PartitionConvolutionTiledOutput( + lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, b)); + + if (partitioned_conv) { + return partitioned_conv; + } + } + return nullptr; +} + +// Partition convolution. +StatusOr PartitionConvolution( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, const Window& conv_window, + HloInstruction* original_hlo, int64 num_partitions, + const SpmdPartitionerOptions& options, HloInstruction* partition_id, + HloModule* module, SpmdBuilder* b) { + TF_RET_CHECK(original_hlo->opcode() == HloOpcode::kConvolution); + + TF_ASSIGN_OR_RETURN( + auto try_partitioned_conv, + PartitionConvolutionBaseCase(lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, num_partitions, + options, partition_id, module, b)); + if (try_partitioned_conv) { + return try_partitioned_conv; + } + + const auto& dnums = original_hlo->convolution_dimension_numbers(); + spmd::ConvolutionDimsMapping mapping; + for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { + int64 lhs_dim = dnums.input_spatial_dimensions(i); + int64 lhs_size = lhs.base_shape().dimensions(lhs_dim); + const auto& wd = original_hlo->window().dimensions(i); + int64 rhs_dim = dnums.kernel_spatial_dimensions(i); + int64 output_dim = dnums.output_spatial_dimensions(i); + if (dot_as_convolution_util::ConvSpatialDimensionIsParallel(wd, lhs_size)) { + mapping.parallel_spatial_dims.emplace_back(); + mapping.parallel_spatial_dims.back().lhs = lhs_dim; + mapping.parallel_spatial_dims.back().rhs = rhs_dim; + mapping.parallel_spatial_dims.back().output = output_dim; + mapping.parallel_spatial_dims.back().spatial = i; + } else { + mapping.non_parallel_spatial_dims.emplace_back(); + mapping.non_parallel_spatial_dims.back().lhs = lhs_dim; + mapping.non_parallel_spatial_dims.back().rhs = rhs_dim; + mapping.non_parallel_spatial_dims.back().output = output_dim; + mapping.non_parallel_spatial_dims.back().spatial = i; + } + } + + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. + auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + + const int64 lhs_parallel_spatial_partitions = + get_partitions_for_dims(lhs.sharding(), mapping.parallel_spatial_dims, 0); + const int64 rhs_parallel_spatial_partitions = + get_partitions_for_dims(rhs.sharding(), mapping.parallel_spatial_dims, 1); + const int64 output_parallel_spatial_partitions = get_partitions_for_dims( + original_hlo->sharding(), mapping.parallel_spatial_dims, 2); + + // Recursively partition on different types of dimensions. + // + // Case 1: Group partitions by parallel spatial dims. 
+ if (lhs_parallel_spatial_partitions == rhs_parallel_spatial_partitions && + lhs_parallel_spatial_partitions == output_parallel_spatial_partitions && + lhs_parallel_spatial_partitions > 1) { + TF_ASSIGN_OR_RETURN(auto try_partitioned_conv, + PartitionConvolutionGroupOnParallelDim( + lhs, rhs, output_base_shape, output_sharding, + conv_window, original_hlo, mapping, num_partitions, + options, partition_id, module, b)); + if (try_partitioned_conv) { + return try_partitioned_conv; + } + } + + return nullptr; +} + +} // namespace + +Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { + auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); + if (dot_dnums) { + // Use HandleDotHelper() for convs that are actually einsums. + spmd::DotGeneralDimsMapping mapping; + for (const auto& dims : dot_dnums->batch_dims) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dims.lhs; + mapping.batch_dims.back().rhs = dims.rhs; + mapping.batch_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->contracting_dims) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dims.lhs; + mapping.contracting_dims.back().rhs = dims.rhs; + mapping.contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.lhs_non_contracting_dims.back().output = dims.output; + } + for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; + mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; + mapping.rhs_non_contracting_dims.back().output = dims.output; + } + auto create_sharded_conv = + [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, + spmd::SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_conv, + dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( + *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); + return b->AddInstruction(std::move(sharded_conv)); + }; + return HandleDotHelper(hlo, mapping, create_sharded_conv); + } + + auto lhs = GetPartitionedHlo(hlo->operand(0)); + auto rhs = GetPartitionedHlo(hlo->operand(1)); + TF_ASSIGN_OR_RETURN( + auto partitioned_conv, + PartitionConvolution(lhs, rhs, hlo->shape(), hlo->sharding(), + hlo->window(), hlo, num_partitions_, options_, + partition_id_, module_, &b_)); + + if (partitioned_conv) { + SetPartitionedHlo(hlo, [&] { return partitioned_conv; }); + return Status::OK(); + } + return DefaultAction(hlo); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc new file mode 100644 index 00000000000..55ebe120d01 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -0,0 +1,1587 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" +#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/numbers.h" + +namespace xla { +namespace spmd { + +Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { + DotGeneralDimsMapping mapping; + const auto& dnums = hlo->dot_dimension_numbers(); + int64 next_output_dim = 0; + for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { + mapping.batch_dims.emplace_back(); + mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); + mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); + mapping.batch_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { + mapping.contracting_dims.emplace_back(); + mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); + mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); + mapping.contracting_dims.back().output = -1; + } + for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { + continue; + } + mapping.lhs_non_contracting_dims.emplace_back(); + mapping.lhs_non_contracting_dims.back().lhs = i; + mapping.lhs_non_contracting_dims.back().rhs = -1; + mapping.lhs_non_contracting_dims.back().output = next_output_dim++; + } + for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { + if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || + absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { + continue; + } + mapping.rhs_non_contracting_dims.emplace_back(); + mapping.rhs_non_contracting_dims.back().lhs = -1; + mapping.rhs_non_contracting_dims.back().rhs = i; + mapping.rhs_non_contracting_dims.back().output = next_output_dim++; + } + auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, + SpmdBuilder* b) -> StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharded_dot_shape, + ShapeInference::InferDotOpShape(l->shape(), r->shape(), + hlo->dot_dimension_numbers())); + return b->AddInstruction(HloInstruction::CreateDot( + sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), + hlo->precision_config())); + }; + return HandleDotHelper(hlo, mapping, create_sharded_dot); +} + +namespace { + +StatusOr PartitionBaseCase( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, 
SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, int64 lhs_batch_partitions, + int64 rhs_batch_partitions, int64 output_batch_partitions, + int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, + int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + int64 output_lhs_non_contracting_partitions, + int64 output_rhs_non_contracting_partitions, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + const HloSharding& lhs_sharding = lhs.sharding(); + const HloSharding& rhs_sharding = rhs.sharding(); + std::vector lhs_to_rhs_indices(lhs.base_shape().rank(), -1); + std::vector lhs_to_output_indices(lhs.base_shape().rank(), -1); + std::vector rhs_to_lhs_indices(rhs.base_shape().rank(), -1); + std::vector rhs_to_output_indices(rhs.base_shape().rank(), -1); + std::vector output_to_lhs_indices(output_base_shape.rank(), -1); + std::vector output_to_rhs_indices(output_base_shape.rank(), -1); + auto populate_indices_mapping = + [&](const DotGeneralDimsMapping::DimsMapping& mapping) { + if (mapping.lhs >= 0) { + lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; + lhs_to_output_indices[mapping.lhs] = mapping.output; + } + if (mapping.rhs >= 0) { + rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; + rhs_to_output_indices[mapping.rhs] = mapping.output; + } + if (mapping.output >= 0) { + output_to_lhs_indices[mapping.output] = mapping.lhs; + output_to_rhs_indices[mapping.output] = mapping.rhs; + } + }; + for (const auto& mapping : dims_mapping.batch_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { + populate_indices_mapping(mapping); + } + auto lhs_sharding_transposed_to_match_rhs = + TransposeShardingWithCollapsedDims(lhs_sharding, lhs_to_rhs_indices, + rhs_to_lhs_indices); + auto rhs_sharding_transposed_to_match_lhs = + TransposeShardingWithCollapsedDims(rhs_sharding, rhs_to_lhs_indices, + lhs_to_rhs_indices); + auto lhs_sharding_transposed_to_match_output = + TransposeShardingWithCollapsedDims(lhs_sharding, lhs_to_output_indices, + output_to_lhs_indices); + auto rhs_sharding_transposed_to_match_output = + TransposeShardingWithCollapsedDims(rhs_sharding, rhs_to_output_indices, + output_to_rhs_indices); + auto output_sharding_transposed_to_match_lhs = + TransposeShardingWithCollapsedDims(output_sharding, output_to_lhs_indices, + lhs_to_output_indices); + auto output_sharding_transposed_to_match_rhs = + TransposeShardingWithCollapsedDims(output_sharding, output_to_rhs_indices, + rhs_to_output_indices); + + // LHS and RHS are partitioned the same way and only partitioned in batch + // dimensions. + if (lhs_batch_partitions == rhs_batch_partitions && + rhs_batch_partitions == num_partitions && + lhs_sharding_transposed_to_match_rhs == rhs_sharding) { + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + dot->set_sharding(*lhs_sharding_transposed_to_match_output); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); + } + + // Try emit batch-partitioned einsum with one operand resharded. Returns + // partitioned HLO or nullptr if the attempt fails. 
If + // may_reshard_with_allreduce is false, reshard must be done using + // all-to-all/collective-permute; otherwise this attempt fails. + auto try_emit_output_batch_partitioned_einsum_with_reshard = + [&](bool may_reshard_with_allreduce) -> StatusOr { + // LHS and output are batch partitioned in the same way. + if (lhs_batch_partitions == num_partitions && + output_batch_partitions == num_partitions && + lhs_sharding_transposed_to_match_output == output_sharding) { + if (!may_reshard_with_allreduce && + !CanReshardWithCollectivePermute( + rhs.sharding(), *lhs_sharding_transposed_to_match_rhs) && + !GetReshardAllToAllSourceTargetDims( + rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) { + return nullptr; + } + auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), b)); + return dot; + } + // RHS and output are batch partitioned in the same way. + if (rhs_batch_partitions == num_partitions && + output_batch_partitions == num_partitions && + rhs_sharding_transposed_to_match_output == output_sharding) { + if (!may_reshard_with_allreduce && + !CanReshardWithCollectivePermute( + lhs.sharding(), *rhs_sharding_transposed_to_match_lhs) && + !GetReshardAllToAllSourceTargetDims( + lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) { + return nullptr; + } + auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); + TF_ASSIGN_OR_RETURN( + auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), b)); + return dot; + } + return nullptr; + }; + + { + // Try batch-parallel by resharding one operand, and not using all-reduce. + TF_ASSIGN_OR_RETURN( + HloInstruction * partitioned_dot, + try_emit_output_batch_partitioned_einsum_with_reshard(false)); + if (partitioned_dot) { + return partitioned_dot; + } + } + + // Try to emit windowed DotGeneral when one operand is partitioned in the same + // way as the output along non-contracting dimensions, but the other operand + // is tiled in other dimensions. + auto emit_windowed_dot_general = + [&](int64 matching_operand, int64 windowing_operand, + bool windowed_at_contracting_dims, + bool windowed_at_batch_dims) -> StatusOr { + CHECK_EQ(matching_operand + windowing_operand, 1); + CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); + auto unpadded_result_buffer_shape = + MakePartitionedShape(output_base_shape, output_sharding); + auto padded_result_buffer_shape = unpadded_result_buffer_shape; + // For windowing at batch/non-contracting dims, we produce the result one + // partition at a time, so we need to pad the shape in case of uneven + // partitioning in order to make dynamic-update-slice in-bound. + if (!windowed_at_contracting_dims) { + padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( + padded_result_buffer_shape, + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output); + } + // Mask the padding area of the windowed operand with zero if there is + // uneven partitioning. + if (windowed_at_contracting_dims) { + auto& to_mask = windowing_operand == 0 ? lhs : rhs; + to_mask = + to_mask.PadWithValue(b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type())))); + } + auto result_buffer = CreateZero(padded_result_buffer_shape, b); + auto iteration = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + + // Create a while loop that computes one window per iteration. 
During each + // iteration, each partition sends its input window to its neighbor using + // collective-permute for the next iteration. + SpmdBuilder body_b("windowed_dot_general_body", original_hlo); + auto param = body_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto l = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); + auto r = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); + auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), param, 2)); + auto i = body_b.AddInstruction( + HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); + + auto partition_id = + lhs.state().collective_ops_creator.create_partition_id(&body_b); + auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, partition_id)); + auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions))); + data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); + auto dot_lhs = l; + auto dot_rhs = r; + if (windowed_at_contracting_dims || windowed_at_batch_dims) { + // Slice the matching operand according to the partitioned contracting + // dimensions on the windowed operand. We do this by treating the matching + // operand as replicated, and resharding it to match the windowed operand. + auto slice_operand = matching_operand == 0 ? l : r; + slice_operand->set_sharding(HloSharding::Replicate()); + auto state = lhs.state(); + state.b = &body_b; + state.partition_id = data_partition_id; + auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) + .Reshard(windowing_operand == 0 + ? *lhs_sharding_transposed_to_match_rhs + : *rhs_sharding_transposed_to_match_lhs) + .hlo(); + slice_operand->clear_sharding(); + if (matching_operand == 0) { + dot_lhs = slice; + } else { + dot_rhs = slice; + } + } + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(dot_lhs, dot_rhs, &body_b)); + if (windowed_at_contracting_dims) { + // Accumulate the partial output to the result buffer. + o = body_b.AddInstruction( + HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); + } else { + // The windowing operand is partitioned along batch/non-contracting + // dimensions, so we need a dynamic-update-slice to save the partial + // output in the result buffer. + auto offsets = MakePartitionOffsets( + o->shape(), + windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output + : *rhs_sharding_transposed_to_match_output, + data_partition_id, &body_b); + o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + o->shape(), o, dot, offsets)); + } + + // ++i + i = body_b.AddInstruction(HloInstruction::CreateBinary( + i->shape(), HloOpcode::kAdd, i, + body_b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); + auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), i, + body_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions))), + ComparisonDirection::kLt)); + // Collective-permute for the next window. We don't need it for the last + // iteration, so we use a conditional around the collective-permute. 
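For reference, a standalone sketch (not part of the patch) of the rotation schedule built just below: on every iteration partition i forwards its current window to partition (i - 1) mod n, so after num_partitions iterations each partition has seen every slice of the windowed operand.

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Source->destination pairs for the per-iteration collective-permute.
    std::vector<std::pair<int64_t, int64_t>> RotationPairs(int64_t num_partitions) {
      std::vector<std::pair<int64_t, int64_t>> sd_pairs(num_partitions);
      for (int64_t source = 0; source < num_partitions; ++source) {
        sd_pairs[source] = {source,
                            (source - 1 + num_partitions) % num_partitions};
      }
      return sd_pairs;
    }
    // For num_partitions = 4 this yields {0->3, 1->0, 2->1, 3->2}, matching
    // the "0 -> n-1, 1 -> 0, 2 -> 1, ..." comment in the loop body below.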
+ HloInstruction* conditional; + { + SpmdBuilder cp_b("window_collective_permute", original_hlo); + { + auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + std::vector> sd_pairs(num_partitions); + for (int64 source = 0; source < num_partitions; ++source) { + // 0 -> n-1, 1 -> 0, 2 -> 1, ... + sd_pairs[source] = {source, + (source - 1 + num_partitions) % num_partitions}; + } + lhs.state() + .collective_ops_creator.create_cross_partition_collective_permute( + &cp_b, p, sd_pairs, (*lhs.state().next_channel_id)++); + } + SpmdBuilder ncp_b("last_iteration_noop", original_hlo); + { + ncp_b.AddInstruction(HloInstruction::CreateParameter( + 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); + } + conditional = body_b.AddInstruction(HloInstruction::CreateConditional( + windowing_operand == 0 ? l->shape() : r->shape(), has_more, + windowing_operand == 0 ? l : r, + module->AddEmbeddedComputation(cp_b.Build()), + windowing_operand == 0 ? l : r, + module->AddEmbeddedComputation(ncp_b.Build()))); + } + if (windowing_operand == 0) { + l = conditional; + } else { + r = conditional; + } + body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); + + SpmdBuilder cond_b("windowed_dot_general_cond", original_hlo); + auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, + ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), + result_buffer->shape(), iteration->shape()}), + "param")); + auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( + iteration->shape(), cond_param, 3)); + cond_b.AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), cond_i, + cond_b.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(num_partitions))), + ComparisonDirection::kLt)); + auto while_loop = b->AddInstruction(HloInstruction::CreateWhile( + cond_param->shape(), module->AddEmbeddedComputation(cond_b.Build()), + module->AddEmbeddedComputation(body_b.Build()), + b->AddInstruction(HloInstruction::CreateTuple( + {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); + windowed_dot_general_loops->push_back({while_loop, windowing_operand, + windowed_at_contracting_dims, + windowed_at_batch_dims}); + auto result = b->AddInstruction(HloInstruction::CreateGetTupleElement( + result_buffer->shape(), while_loop, 2)); + if (!ShapeUtil::Compatible(padded_result_buffer_shape, + unpadded_result_buffer_shape)) { + result = b->AddInstruction(HloInstruction::CreateSlice( + unpadded_result_buffer_shape, result, + std::vector(padded_result_buffer_shape.rank(), 0), + unpadded_result_buffer_shape.dimensions(), + std::vector(padded_result_buffer_shape.rank(), 1))); + } + return result; + }; + if (output_lhs_non_contracting_partitions == num_partitions && + output_sharding_transposed_to_match_lhs == lhs_sharding && + ShapeSizeInBytes(rhs.base_shape()) >= + threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (rhs_contracting_partitions == num_partitions) { + return emit_windowed_dot_general(0, 1, true, false); + } + if (rhs_non_contracting_partitions == num_partitions) { + return emit_windowed_dot_general(0, 1, false, false); + } + if (rhs_batch_partitions == num_partitions) { + return emit_windowed_dot_general(0, 1, false, true); + } + } + if (output_rhs_non_contracting_partitions == num_partitions && + output_sharding_transposed_to_match_rhs == rhs_sharding && + ShapeSizeInBytes(lhs.base_shape()) >= + 
threshold_for_windowed_einsum_mib * 1024 * 1024) { + if (lhs_contracting_partitions == num_partitions) { + return emit_windowed_dot_general(1, 0, true, false); + } + if (lhs_non_contracting_partitions == num_partitions) { + return emit_windowed_dot_general(1, 0, false, false); + } + if (lhs_batch_partitions == num_partitions) { + return emit_windowed_dot_general(1, 0, false, true); + } + } + + { + // Try batch-parallel by resharding one operand, and allowing all-reduce. + TF_ASSIGN_OR_RETURN( + HloInstruction * partitioned_dot, + try_emit_output_batch_partitioned_einsum_with_reshard(true)); + if (partitioned_dot) { + return partitioned_dot; + } + } + + // LHS and RHS have the same partitioned contracting dimensions. + if (lhs_contracting_partitions == rhs_contracting_partitions && + lhs_contracting_partitions == num_partitions) { + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); + // Pad both sides with zero, since NaN at one side cannot be masked by zero + // on the other side. + if (ShapeSizeInBytes(lhs.base_shape()) < + ShapeSizeInBytes(rhs.base_shape())) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + auto ar = + lhs.state().collective_ops_creator.create_cross_partition_all_reduce( + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + (*lhs.state().next_channel_id)++); + ar->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(ar, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); + } + + // LHS and output have the same partitioned non-contracting dimensions. + if (lhs_non_contracting_partitions == num_partitions && + output_lhs_non_contracting_partitions == num_partitions && + lhs_sharding_transposed_to_match_output == output_sharding) { + auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs.hlo(), rhs_replicated, b)); + return dot; + } + + // RHS and output have the same partitioned non-contracting dimensions. + if (rhs_non_contracting_partitions == num_partitions && + output_rhs_non_contracting_partitions == num_partitions && + rhs_sharding_transposed_to_match_output == output_sharding) { + auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); + TF_ASSIGN_OR_RETURN(auto dot, + create_sharded_dot(lhs_replicated, rhs.hlo(), b)); + return dot; + } + + // Output is batch partitioned. + if (output_batch_partitions == num_partitions) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along LHS non-contracting dimensions. + if (output_lhs_non_contracting_partitions == num_partitions) { + auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + replicated_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along RHS non-contracting dimensions. 
+ if (output_rhs_non_contracting_partitions == num_partitions) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } + + // Returns true if it is beneficial to reshard the operand at `operand_idx` + // across the contracting dimension. + const auto should_partition_contracting_dim = [&](int64 operand_idx) { + if (!output_sharding.IsReplicated()) { + return false; + } + + if (operand_idx == 0) { + // If LHS and output are replicated, we compare the cost of all-gather + // on RHS vs all-reduce on the output. + return (rhs_contracting_partitions == num_partitions) && + lhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(rhs.base_shape()) > + ShapeUtil::ElementsIn(output_base_shape); + } else { + return (lhs_contracting_partitions == num_partitions) && + rhs.sharding().IsReplicated() && + ShapeUtil::ElementsIn(lhs.base_shape()) > + ShapeUtil::ElementsIn(output_base_shape); + } + }; + + // When the output is replicated and one of the operands is partitioned along + // contracting dimension, align the other operand to be partitioned along + // the contracting dimensions. + if (output_sharding.IsReplicated() && (should_partition_contracting_dim(0) || + should_partition_contracting_dim(1))) { + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(output_base_shape.element_type()))); + if (should_partition_contracting_dim(0)) { + lhs = + lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); + rhs = rhs.PadWithValue(zero); + } else { + lhs = lhs.PadWithValue(zero); + rhs = + rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); + } + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.hlo(), rhs.hlo(), b)); + return lhs.state().collective_ops_creator.create_cross_partition_all_reduce( + b, dot, MakeBinaryAdd(output_base_shape.element_type(), module), {}, + (*lhs.state().next_channel_id)++); + } + return nullptr; +} + +StatusOr PartitionDot( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops); + +StatusOr PartitionDotGroupOnBatch( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + int64 lhs_contracting_partitions, int64 rhs_contracting_partitions, + int64 lhs_non_contracting_partitions, int64 rhs_non_contracting_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + std::vector lhs_dims; + std::vector rhs_dims; + std::vector output_dims; + auto lhs_sharding_dims_adjusted_to_output = + lhs.sharding().IsReplicated() + ? std::vector(lhs.base_shape().rank(), 1) + : lhs.sharding().tile_assignment().dimensions(); + auto rhs_sharding_dims_adjusted_to_output = + rhs.sharding().IsReplicated() + ? 
std::vector(rhs.base_shape().rank(), 1) + : rhs.sharding().tile_assignment().dimensions(); + auto output_sharding_dims_adjusted_to_lhs = + output_sharding.tile_assignment().dimensions(); + bool lhs_rhs_dims_matching = true; + for (const auto& dim : dims_mapping.batch_dims) { + lhs_dims.push_back(dim.lhs); + rhs_dims.push_back(dim.rhs); + output_dims.push_back(dim.output); + if (lhs_sharding_dims_adjusted_to_output[dim.lhs] != + rhs_sharding_dims_adjusted_to_output[dim.rhs]) { + lhs_rhs_dims_matching = false; + } + lhs_sharding_dims_adjusted_to_output[dim.lhs] = + output_sharding.tile_assignment().dim(dim.output); + rhs_sharding_dims_adjusted_to_output[dim.rhs] = + output_sharding.tile_assignment().dim(dim.output); + output_sharding_dims_adjusted_to_lhs[dim.output] = + lhs.sharding().tile_assignment().dim(dim.lhs); + } + auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); + PartitionedHlo per_group_lhs = lhs; + PartitionedHlo per_group_rhs = rhs; + auto lhs_sharding = lhs.sharding(); + auto rhs_sharding = rhs.sharding(); + if (lhs_rhs_dims_matching) { + auto lhs_grouped = GroupShardingOnDims(lhs.sharding(), lhs_dims); + auto rhs_grouped = GroupShardingOnDims(rhs.sharding(), rhs_dims); + if (ShapeUtil::ByteSizeOf(lhs.base_shape()) > + ShapeUtil::ByteSizeOf(rhs.base_shape())) { + rhs_grouped = AlignGroupsWith(std::move(rhs_grouped), lhs_grouped); + rhs = rhs.Reshard(UngroupSharding(rhs_grouped)); + } else { + lhs_grouped = AlignGroupsWith(std::move(lhs_grouped), rhs_grouped); + lhs = lhs.Reshard(UngroupSharding(lhs_grouped)); + } + auto reshaped_output_tiling = output_sharding.tile_assignment(); + reshaped_output_tiling.Reshape(output_sharding_dims_adjusted_to_lhs); + output_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_output_tiling), + output_dims), + lhs_grouped); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + lhs.state(), lhs_grouped.device_groups, b); + lhs.hlo()->set_sharding(lhs_grouped.sharding); + rhs.hlo()->set_sharding(rhs_grouped.sharding); + CHECK(lhs.hlo() != rhs.hlo() || + lhs_grouped.sharding == rhs_grouped.sharding); + per_group_lhs = PartitionedHlo( + lhs.hlo(), GetPerGroupBaseShape(lhs_grouped, lhs.base_shape()), + per_group_partitioner_state); + per_group_rhs = PartitionedHlo( + rhs.hlo(), GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), + per_group_partitioner_state); + } else { + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + lhs.state(), output_grouped.device_groups, b); + auto reshard_to_output_batch = + [&](PartitionedHlo operand, absl::Span batch_dims, + absl::Span contracting_dims, + absl::Span non_contracting_dims, + int64 contracting_dim_partitions, + int64 non_contracting_dim_partitions, + int64 other_contracting_dim_partitions, + std::vector* sharding_dims_adjusted_to_output) + -> absl::optional { + if (operand.sharding().IsReplicated()) { + auto partially_sharded = PerGroupSliceFromReplicated( + operand.hlo(), operand.state().partition_id, + output_grouped.device_groups, batch_dims, + output_grouped.group_dim_sizes, b); + partially_sharded->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(partially_sharded, partially_sharded->shape(), + per_group_partitioner_state); + } + auto reshaped_tiling = operand.sharding().tile_assignment(); + // It's possible that the operand is not initially sharded on batch + // dimensions in the same way as the output, although being tiled. 
In that + // case, the current sharding_dims_adjusted_to_output may contain more + // partitions than available devices. We remove partitioning on other + // dimensions. + if (Product(*sharding_dims_adjusted_to_output) > + reshaped_tiling.num_elements()) { + if (Product(*sharding_dims_adjusted_to_output) % + reshaped_tiling.num_elements() != + 0) { + return absl::nullopt; + } + int64 ratio = Product(*sharding_dims_adjusted_to_output) / + reshaped_tiling.num_elements(); + if (ratio == non_contracting_dim_partitions && + (ratio != contracting_dim_partitions || + contracting_dim_partitions == other_contracting_dim_partitions)) { + for (int64 dim : non_contracting_dims) { + (*sharding_dims_adjusted_to_output)[dim] = 1; + } + } else if (ratio == contracting_dim_partitions) { + for (int64 dim : contracting_dims) { + (*sharding_dims_adjusted_to_output)[dim] = 1; + } + } + } + // If the operand is initially sharded more ways than the output in the + // batch dimensions, sharding_dims_adjusted_to_output currently contains + // fewer partitions than available devices. We do not handle this case. + if (Product(*sharding_dims_adjusted_to_output) < + reshaped_tiling.num_elements()) { + return absl::nullopt; + } + reshaped_tiling.Reshape(*sharding_dims_adjusted_to_output); + auto grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_tiling), batch_dims), + output_grouped); + auto resharded = operand.Reshard(UngroupSharding(grouped)); + resharded.hlo()->set_sharding(grouped.sharding); + return PartitionedHlo(resharded.hlo(), + GetPerGroupBaseShape(grouped, operand.base_shape()), + per_group_partitioner_state); + }; + std::vector lhs_contracting_dims; + std::vector rhs_contracting_dims; + lhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); + rhs_contracting_dims.reserve(dims_mapping.contracting_dims.size()); + for (const auto& dim : dims_mapping.contracting_dims) { + lhs_contracting_dims.push_back(dim.lhs); + rhs_contracting_dims.push_back(dim.rhs); + } + std::vector lhs_non_contracting_dims; + std::vector rhs_non_contracting_dims; + lhs_non_contracting_dims.reserve( + dims_mapping.lhs_non_contracting_dims.size()); + rhs_non_contracting_dims.reserve( + dims_mapping.rhs_non_contracting_dims.size()); + for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { + lhs_non_contracting_dims.push_back(dim.lhs); + } + for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { + rhs_non_contracting_dims.push_back(dim.rhs); + } + if (auto resharded = reshard_to_output_batch( + lhs, lhs_dims, lhs_contracting_dims, lhs_non_contracting_dims, + lhs_contracting_partitions, lhs_non_contracting_partitions, + rhs_contracting_partitions, + &lhs_sharding_dims_adjusted_to_output)) { + per_group_lhs = *resharded; + } else { + return nullptr; + } + if (auto resharded = reshard_to_output_batch( + rhs, rhs_dims, rhs_contracting_dims, rhs_non_contracting_dims, + rhs_contracting_partitions, rhs_non_contracting_partitions, + lhs_contracting_partitions, + &rhs_sharding_dims_adjusted_to_output)) { + per_group_rhs = *resharded; + } else { + return nullptr; + } + CHECK(lhs.hlo() != rhs.hlo() || + per_group_lhs.sharding() == per_group_rhs.sharding()); + } + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot(per_group_lhs, per_group_rhs, + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / output_grouped.device_groups.size(), + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, + 
windowed_dot_general_loops)); + // Make sure the operands' sharding are set to the ungrouped ones. + lhs.hlo()->set_sharding(lhs_sharding); + rhs.hlo()->set_sharding(rhs_sharding); + dot->set_sharding(UngroupSharding(output_grouped)); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +StatusOr PartitionDotGroupOnNonContracting( + bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, + int64 matching_contracting_partitions, int64 other_contracting_partitions, + int64 matching_non_contracting_partitions, + int64 other_non_contracting_partitions, + int64 output_other_non_contracting_partitions, + const Shape& output_base_shape, const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + const bool may_replicate_other_contracting_dims = + (other_contracting_partitions == matching_non_contracting_partitions && + other_non_contracting_partitions == + output_other_non_contracting_partitions); + const bool may_replicate_other_non_contracting_dims = + matching_non_contracting_partitions == other_non_contracting_partitions && + matching_contracting_partitions == other_contracting_partitions; + std::vector other_group_dims; + if (may_replicate_other_contracting_dims && + (!may_replicate_other_non_contracting_dims || + ShapeUtil::ByteSizeOf(other.base_shape()) <= + ShapeUtil::ByteSizeOf(output_base_shape))) { + for (const auto& dim : dims_mapping.contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else if (may_replicate_other_non_contracting_dims) { + for (const auto& dim : lhs_matching + ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else if (!other.sharding().IsReplicated()) { + return nullptr; + } + auto matching_sharding_dims = + matching.sharding().tile_assignment().dimensions(); + std::vector matching_dims; + std::vector output_dims; + // Make sure the partitioning on matching's non-contracting dimensions + // defines the same device groups for both matching and output. + for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims + : dims_mapping.rhs_non_contracting_dims) { + int64 md = lhs_matching ? 
dim.lhs : dim.rhs; + matching_sharding_dims[md] = + output_sharding.tile_assignment().dim(dim.output); + matching_dims.push_back(md); + output_dims.push_back(dim.output); + } + auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); + auto reshaped_matching_tiling = matching.sharding().tile_assignment(); + reshaped_matching_tiling.Reshape(matching_sharding_dims); + auto matching_grouped = AlignGroupsWith( + GroupShardingOnDims(HloSharding::Tile(reshaped_matching_tiling), + matching_dims), + output_grouped); + matching = matching.Reshard(UngroupSharding(matching_grouped)); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + matching.state(), matching_grouped.device_groups, b); + matching.hlo()->set_sharding(matching_grouped.sharding); + auto matching_p = PartitionedHlo( + matching.hlo(), + GetPerGroupBaseShape(matching_grouped, matching.base_shape()), + per_group_partitioner_state); + + auto partially_replicated_other = other.hlo(); + if (!other.sharding().IsReplicated()) { + auto other_grouped = + AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims), + output_grouped, /*ignore_group_order=*/true); + other = other.Reshard(UngroupSharding(other_grouped)); + partially_replicated_other = + other.ReplicatePartial(other_grouped.group_dims); + partially_replicated_other->set_sharding(other_grouped.sharding); + } + auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), + per_group_partitioner_state); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot(lhs_matching ? matching_p : other_p, + lhs_matching ? other_p : matching_p, + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / matching_grouped.device_groups.size(), + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + // Reset matching's sharding to the ungrouped one. + matching.hlo()->set_sharding(UngroupSharding(matching_grouped)); + return dot; +} + +// Recursive partitioning function. If there are partial dimensions matching in +// the operands and output, group the devices and recursively partition the +// in-group dot. +StatusOr PartitionDot( + PartitionedHlo lhs, PartitionedHlo rhs, const Shape& output_base_shape, + const HloSharding& output_sharding, + const DotGeneralDimsMapping& dims_mapping, int64 num_partitions, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot, + HloModule* module, HloInstruction* original_hlo, + int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, + std::vector* + windowed_dot_general_loops) { + // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
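As a standalone sketch (not part of the patch; the names are hypothetical), the partition counts computed by the helper lambda defined just below are simply products of tile-assignment extents over the mapped dimensions:

    #include <cstdint>
    #include <vector>

    // Number of partitions along a set of dimensions of a tiled sharding:
    // the product of the tile-assignment extents on those dimensions.
    int64_t PartitionsForDims(const std::vector<int64_t>& tile_assignment_dims,
                              const std::vector<int64_t>& dims) {
      int64_t partitions = 1;
      for (int64_t dim : dims) {
        partitions *= tile_assignment_dims[dim];
      }
      return partitions;
    }
    // E.g. a tile assignment of shape {2, 4, 1} gives 2 partitions over dims
    // {0} and 8 over dims {0, 1}; a tile-maximal (replicated) sharding counts
    // as 1, which is why the lambda below returns early in that case.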
+ auto get_partitions_for_dims = + [&](const HloSharding& sharding, + absl::Span dims, + int lhs_rhs_or_output) { + int64 partitions = 1; + if (sharding.IsTileMaximal()) { + return partitions; + } + for (const auto& dim : dims) { + if (lhs_rhs_or_output == 0) { + partitions *= sharding.tile_assignment().dim(dim.lhs); + } else if (lhs_rhs_or_output == 1) { + partitions *= sharding.tile_assignment().dim(dim.rhs); + } else { + CHECK_EQ(lhs_rhs_or_output, 2); + partitions *= sharding.tile_assignment().dim(dim.output); + } + } + return partitions; + }; + const int64 lhs_batch_partitions = + get_partitions_for_dims(lhs.sharding(), dims_mapping.batch_dims, 0); + const int64 rhs_batch_partitions = + get_partitions_for_dims(rhs.sharding(), dims_mapping.batch_dims, 1); + const int64 output_batch_partitions = + get_partitions_for_dims(output_sharding, dims_mapping.batch_dims, 2); + const int64 lhs_contracting_partitions = + get_partitions_for_dims(lhs.sharding(), dims_mapping.contracting_dims, 0); + const int64 rhs_contracting_partitions = + get_partitions_for_dims(rhs.sharding(), dims_mapping.contracting_dims, 1); + const int64 lhs_non_contracting_partitions = get_partitions_for_dims( + lhs.sharding(), dims_mapping.lhs_non_contracting_dims, 0); + const int64 rhs_non_contracting_partitions = get_partitions_for_dims( + rhs.sharding(), dims_mapping.rhs_non_contracting_dims, 1); + const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.lhs_non_contracting_dims, 2); + const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( + output_sharding, dims_mapping.rhs_non_contracting_dims, 2); + TF_ASSIGN_OR_RETURN( + auto try_partitioned_dot, + PartitionBaseCase( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, + lhs_contracting_partitions, rhs_contracting_partitions, + lhs_non_contracting_partitions, rhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + if (try_partitioned_dot) { + return try_partitioned_dot; + } + + // Recursively partition on different types of dimensions. + // + // Case 1: Group partitions by batch. + if ((lhs_batch_partitions == output_batch_partitions || + rhs_batch_partitions == output_batch_partitions) && + output_batch_partitions > 1) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnBatch( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, lhs_contracting_partitions, + rhs_contracting_partitions, lhs_non_contracting_partitions, + rhs_non_contracting_partitions, create_sharded_dot, module, + original_hlo, threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // Case 2: Group partitions by non-contracting dimensions. + const bool may_group_on_lhs_non_contracting = + lhs_non_contracting_partitions == output_lhs_non_contracting_partitions && + lhs_non_contracting_partitions > 1; + const bool may_group_on_rhs_non_contracting = + rhs_non_contracting_partitions == output_rhs_non_contracting_partitions && + rhs_non_contracting_partitions > 1; + if (may_group_on_lhs_non_contracting || may_group_on_rhs_non_contracting) { + // If both match output non-contracting dimensions, choose the one which + // will result in smaller replication of the other operand. 
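A standalone sketch (not part of the patch) of the tie-break applied in the comparison right below when both groupings are possible: prefer the side whose grouping replicates fewer bytes of the other, already sharded, operand.

    #include <cstdint>

    // Group on the LHS non-contracting dimensions when replicating the RHS
    // shard across those groups is no more expensive than the symmetric
    // alternative of grouping on the RHS side.
    bool PreferLhsNonContractingGrouping(int64_t lhs_nc_partitions,
                                         int64_t rhs_shard_bytes,
                                         int64_t rhs_nc_partitions,
                                         int64_t lhs_shard_bytes) {
      return lhs_nc_partitions * rhs_shard_bytes <=
             rhs_nc_partitions * lhs_shard_bytes;
    }
    // E.g. lhs_nc_partitions = 4 with a 1 MiB RHS shard vs. rhs_nc_partitions
    // = 2 with an 8 MiB LHS shard: 4 MiB <= 16 MiB, so group on the LHS side.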
+ const bool lhs_matching = + may_group_on_lhs_non_contracting && + (!may_group_on_rhs_non_contracting || + lhs_non_contracting_partitions * + ShapeUtil::ByteSizeOf(rhs.hlo()->shape()) <= + rhs_non_contracting_partitions * + ShapeUtil::ByteSizeOf(lhs.hlo()->shape())); + + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + lhs_matching, lhs_matching ? lhs : rhs, lhs_matching ? rhs : lhs, + lhs_matching ? lhs_contracting_partitions + : rhs_contracting_partitions, + lhs_matching ? rhs_contracting_partitions + : lhs_contracting_partitions, + lhs_matching ? lhs_non_contracting_partitions + : rhs_non_contracting_partitions, + lhs_matching ? rhs_non_contracting_partitions + : lhs_non_contracting_partitions, + lhs_matching ? output_rhs_non_contracting_partitions + : output_lhs_non_contracting_partitions, + output_base_shape, output_sharding, dims_mapping, num_partitions, + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // Default action. + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(lhs.Replicate().hlo(), + rhs.Replicate().hlo(), b)); + dot->set_sharding(HloSharding::Replicate()); + return PartitionedHlo(dot, output_base_shape, lhs.state()) + .Reshard(output_sharding) + .hlo(); +} + +} // namespace + +Status SpmdPartitioningVisitor::HandleDotHelper( + HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, + const std::function( + HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { + auto& lhs = GetPartitionedHlo(hlo->operand(0)); + auto& rhs = GetPartitionedHlo(hlo->operand(1)); + TF_ASSIGN_OR_RETURN( + auto partitioned_dot, + PartitionDot(lhs, rhs, hlo->shape(), hlo->sharding(), dims_mapping, + num_partitions_, create_sharded_dot, module_, hlo, + options_.threshold_for_windowed_einsum_mib, &b_, + &windowed_dot_general_loops_)); + SetPartitionedHlo(hlo, [&] { return partitioned_dot; }); + return Status::OK(); +} + +namespace { + +// Finds a cluster of nodes that produce the inputs for `hlo` which only depend +// on small operands, which means the cluster should start with broadcasts, +// constants and iotas. All other internal nodes must be non-side-effecting +// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for +// the following graph, +// +// a -> broadcast -> multiply +// iota ---> add--/ +// constant/ +// +// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return +// <{broadcast, iota, constant, add, multiply}, [a]>. 
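A simplified standalone sketch (not part of the patch; the Node type is hypothetical) of the traversal the function below performs: leaf-like nodes (broadcast/constant/iota) are accepted and their operands become the returned small operands, elementwise nodes over accepted inputs are accepted as well, and anything else means the cluster cannot be sunk, so the result is empty.

    #include <unordered_set>
    #include <utility>
    #include <vector>

    struct Node {
      enum Kind { kLeafLike, kElementwise, kOther } kind;
      std::vector<Node*> operands;
    };

    std::pair<std::unordered_set<Node*>, std::vector<Node*>> FindCluster(
        Node* root) {
      std::unordered_set<Node*> cluster;
      std::unordered_set<Node*> operand_set;   // Deduplicates small operands.
      std::vector<Node*> small_operands;
      std::vector<Node*> worklist = {root};
      while (!worklist.empty()) {
        Node* node = worklist.back();
        worklist.pop_back();
        if (cluster.count(node) > 0) continue;
        if (node->kind == Node::kLeafLike) {
          cluster.insert(node);
          for (Node* operand : node->operands) {
            if (operand_set.insert(operand).second) {
              small_operands.push_back(operand);
            }
          }
        } else if (node->kind == Node::kElementwise) {
          cluster.insert(node);
          for (Node* operand : node->operands) worklist.push_back(operand);
        } else {
          return {{}, {}};  // Not a sinkable cluster.
        }
      }
      return {std::move(cluster), std::move(small_operands)};
    }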
+std::pair, std::vector> +FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { + absl::flat_hash_set nodes_found; + std::vector new_operands; + absl::flat_hash_set new_operands_set; + std::vector worklist; + worklist.push_back(hlo); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (nodes_found.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast || + inst->opcode() == HloOpcode::kConstant || + inst->opcode() == HloOpcode::kIota) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + auto res = new_operands_set.emplace(o); + if (res.second) { + new_operands.push_back(o); + } + } + } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + nodes_found.insert(inst); + for (auto o : inst->operands()) { + worklist.push_back(o); + } + } else { + nodes_found.clear(); + new_operands.clear(); + break; + } + } + return {std::move(nodes_found), std::move(new_operands)}; +} + +// Moves a cluster of memory-reducing nodes into the windowed dot-general loop +// on contracting dimensions. Such a loop has a dynamic slice on the +// non-windowed operand. If we move the input nodes into the loop, the +// dynamic-slice could be merged with them by later optimization passes, which +// reduces memory. +// +// small_operands small_operands +// | | +// input_nodes loop { | +// | => input_nodes +// loop { | | +// dynamic-slice dynamic-slice +// ... ... +// } } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes. +Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + HloInstruction* loop, int64 non_windowed_operand_index) { + auto input_tuple = loop->mutable_operand(0); + auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); + auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); + auto to_sink = std::move(input_nodes.first); + auto new_operands = std::move(input_nodes.second); + if (to_sink.empty()) { + return Status::OK(); + } + auto computation = loop->parent(); + // Replace the old operand with a tuple of the found small operands. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_input_subtuple)); + + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto old_body_param_users = body_param->users(); + // Update all tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body->root_instruction()}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), + {non_windowed_operand_index}) = + new_input_subtuple->shape(); + } + // Now update the loop body. + auto new_operand_tuple_inside = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, non_windowed_operand_index)); + TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( + non_windowed_operand_index, new_operand_tuple_inside)); + + // Create nodes inside the loop body. 
+ std::vector worklist; + absl::flat_hash_map outside_to_inside; + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_operand_tuple_inside, i)); + add_users_if_available(new_operands[i]); + } + // HLOs to sink without operands. + std::vector nullaries_to_sink; + for (auto inst : to_sink) { + if (inst->operand_count() == 0) { + nullaries_to_sink.push_back(inst); + } + } + // Sort nullaries_to_sink to make it deterministic. + absl::c_sort(nullaries_to_sink, + [](const HloInstruction* a, const HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + worklist.reserve(nullaries_to_sink.size()); + for (auto inst : nullaries_to_sink) { + worklist.push_back(inst); + } + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + std::vector inst_new_operands(inst->operand_count()); + for (int64 i = 0; i < inst->operand_count(); ++i) { + inst_new_operands[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction( + inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); + add_users_if_available(inst); + } + TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); + for (auto ou : old_body_param_users) { + if (ou->opcode() == HloOpcode::kGetTupleElement && + ou->tuple_index() == non_windowed_operand_index) { + TF_RETURN_IF_ERROR( + ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); + TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); + } + } + return Status::OK(); +} + +// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into +// the windowed dot-general loop on non-contracting dimensions. Such a loop has +// a dynamic-update-slice at the output. If we move the user nodes into the loop +// and before the dynamic-update-slice, the user nodes can operate on smaller +// shapes, which reduces memory. +// +// small_operands small_operands +// | | => | | +// | | loop { loop { | | +// | | conv | broadcast conv +// | | | | | / +// | | dynamic-update-slice | dynamic-slice / +// | | | | | / +// | | } | | multiply----- +// |broadcast / | / +// | | / reduce +// |multiply-- | +// \ | dynamic-update-slice +// reduce } +// +// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice +// with the input nodes (broadcast). +Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + HloInstruction* loop) { + CHECK_EQ(loop->user_count(), 1); + // There should be a single direct user of the while loop, which is the + // gte for element 2, i.e., the dot output. + auto user_gte = loop->users().front(); + CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); + CHECK_EQ(user_gte->tuple_index(), 2); + auto computation = loop->parent(); + + // Find the reduce outputs and the input nodes they depend on, if input nodes + // only have small operands. 
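The sinking code above only clones an instruction once every operand already has an inside-the-loop clone, seeding the worklist with the tuple-element reads of the small operands and with the deterministically sorted nullary ops. The same ordering rule as a standalone sketch over a toy graph type (Node and EmitInDependencyOrder are illustrative; seeds are assumed to be ready, i.e. operand-free or fed entirely from outside the cluster):

#include <algorithm>
#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> operands;
  std::vector<Node*> users;
};

// Emits cluster nodes in operand-before-user order using a worklist.
std::vector<Node*> EmitInDependencyOrder(
    const std::unordered_set<Node*>& cluster, std::vector<Node*> worklist) {
  std::vector<Node*> order;
  std::unordered_set<Node*> emitted;
  auto ready = [&](Node* n) {
    return std::all_of(n->operands.begin(), n->operands.end(), [&](Node* o) {
      return emitted.count(o) > 0 || cluster.count(o) == 0;
    });
  };
  while (!worklist.empty()) {
    Node* n = worklist.back();
    worklist.pop_back();
    if (emitted.count(n) > 0) continue;
    emitted.insert(n);
    order.push_back(n);  // clone point: operands are emitted or external
    for (Node* u : n->users) {
      if (cluster.count(u) > 0 && emitted.count(u) == 0 && ready(u)) {
        worklist.push_back(u);
      }
    }
  }
  return order;
}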
+ absl::flat_hash_set to_move; + std::vector new_operands; + absl::flat_hash_set new_operands_set; + std::vector reduce_outputs; + std::vector worklist; + Shape padded_shape = user_gte->shape(); + Shape unpadded_shape = user_gte->shape(); + auto original_output = user_gte; + + if (user_gte->user_count() == 1 && + user_gte->users().back()->opcode() == HloOpcode::kSlice) { + original_output = user_gte->users().back(); + unpadded_shape = original_output->shape(); + } + for (auto u : original_output->users()) { + worklist.push_back(u); + } + to_move.insert(original_output); + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (to_move.count(inst) > 0) { + continue; + } + // We only support reduces with a simple reduction function, since we may need + // to accumulate across iterations manually. + if (inst->opcode() == HloOpcode::kReduce && + inst->to_apply()->instruction_count() == 3 && + inst->to_apply()->num_parameters() == 2 && + inst->to_apply()->root_instruction()->IsElementwise()) { + to_move.insert(inst); + auto other_operand = inst->mutable_operand(1); + auto res = new_operands_set.emplace(other_operand); + if (res.second) { + new_operands.push_back(other_operand); + } + reduce_outputs.push_back(inst); + } else if (inst != computation->root_instruction() && + inst->user_count() > 0 && inst->IsElementwise() && + !inst->HasSideEffectNoRecurse() && + inst->opcode() != HloOpcode::kAllReduce && + absl::c_all_of(inst->operands(), + [inst](const HloInstruction* o) { + return ShapeUtil::CompatibleIgnoringElementType( + o->shape(), inst->shape()); + })) { + // For an elementwise op, we need to make sure that it depends only on + // nodes already in to_move and nodes with small operands. + bool can_include = true; + for (auto operand : inst->operands()) { + if (to_move.count(operand) > 0) { + continue; + } + auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand); + if (find_result.first.empty()) { + can_include = false; + break; + } + for (auto n : find_result.first) { + to_move.insert(n); + } + for (auto new_operand : find_result.second) { + auto res = new_operands_set.insert(new_operand); + if (res.second) { + new_operands.push_back(new_operand); + } + } + } + if (!can_include) { + to_move.clear(); + break; + } + to_move.insert(inst); + for (auto u : inst->users()) { + worklist.push_back(u); + } + } else { + to_move.clear(); + break; + } + } + // If nothing is found, to_move could contain only original_output, or have been cleared + // by the above code. + if (to_move.size() <= 1) { + return Status::OK(); + } + + // We will replace the original loop output with reduce-shape outputs. Create + // the initial buffers before the loop. + for (auto out : reduce_outputs) { + auto padded_out_shape = out->shape(); + int64 operand_dim = 0; + int64 output_dim = 0; + while (output_dim < padded_out_shape.rank()) { + if (absl::c_linear_search(out->dimensions(), operand_dim)) { + // Dimension collapsed. + ++operand_dim; + continue; + } + // Kept dimensions have the same size as the padded shape.
+ padded_out_shape.set_dimensions(output_dim, + padded_shape.dimensions(operand_dim)); + ++operand_dim; + ++output_dim; + } + auto broadcast = + computation->AddInstruction(HloInstruction::CreateBroadcast( + padded_out_shape, + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(out->shape().element_type()))), + {})); + new_operands.push_back(broadcast); + } + + auto input_tuple = loop->mutable_operand(0); + // Create the new input subtuple that contains the small operands and the + // reduce-shape result buffers. + auto new_input_subtuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); + auto body = loop->while_body(); + auto body_param = body->parameter_instruction(0); + auto body_root = body->root_instruction(); + CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); + // Update tuple shapes. + for (auto tuple : std::vector{ + input_tuple, loop, loop->while_condition()->parameter_instruction(0), + body_param, body_root}) { + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = + new_input_subtuple->shape(); + } + auto new_loop_input = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_input_subtuple->shape(), body_param, 2)); + + // Now create the moved nodes inside the loop body. + absl::flat_hash_map outside_to_inside; + worklist.clear(); + auto add_users_if_available = [&](HloInstruction* inst) { + for (auto u : inst->users()) { + if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && + absl::c_all_of(u->operands(), [&](const HloInstruction* o) { + return outside_to_inside.count(o) > 0; + })) { + worklist.push_back(u); + } + } + }; + for (int64 i = 0; i < new_operands.size(); ++i) { + outside_to_inside[new_operands[i]] = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + new_operands[i]->shape(), new_loop_input, i)); + add_users_if_available(new_operands[i]); + } + // The elementwise nodes will be created with sliced shape. The original loop + // output corresponds to the dynamic-update-slice's update slice. + auto dus = body_root->mutable_operand(2); + CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); + outside_to_inside[original_output] = dus->mutable_operand(1); + add_users_if_available(original_output); + std::vector slice_offsets(padded_shape.rank()); + for (int64 i = 0; i < slice_offsets.size(); ++i) { + slice_offsets[i] = dus->mutable_operand(i + 2); + } + auto get_slice = [&](HloInstruction* padded) { + return body->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + padded->shape().element_type()), + padded, slice_offsets, dus->operand(1)->shape().dimensions())); + }; + // Helper functions to create nodes with small operands. 
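The helper lambdas that follow (add_broadcast, add_iota, add_constant) share one trick: materialize the small operand at the padded full shape, then take the same dynamic slice as the loop's current update window, so each iteration sees a correctly positioned piece. A 1-D toy version of that pad-then-slice pattern (illustrative only; a scalar operand and contiguous offsets are assumed):

#include <cstdint>
#include <vector>

std::vector<float> SliceOfPaddedBroadcast(float scalar, int64_t padded_size,
                                          int64_t offset, int64_t slice_size) {
  std::vector<float> padded(padded_size, scalar);  // broadcast to padded shape
  return std::vector<float>(padded.begin() + offset,
                            padded.begin() + offset + slice_size);  // dynamic-slice
}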
+ auto add_broadcast = [&](const HloInstruction* broadcast) { + auto padded_operand_shape = broadcast->operand(0)->shape(); + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + padded_operand_shape.set_dimensions( + i, padded_shape.dimensions(broadcast->dimensions(i))); + } + auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)], + padded_operand_shape, nullptr, body); + outside_to_inside[broadcast] = + get_slice(body->AddInstruction(broadcast->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + padded_operand_shape.element_type()), + {padded_operand}))); + }; + auto add_iota = [&](const HloInstruction* iota) { + outside_to_inside[iota] = + get_slice(body->AddInstruction(iota->CloneWithNewOperands( + ShapeUtil::ChangeElementType(padded_shape, + iota->shape().element_type()), + {}))); + }; + auto add_constant = [&](const HloInstruction* constant) { + outside_to_inside[constant] = body->AddInstruction(constant->Clone()); + outside_to_inside[constant] = get_slice( + PadToShape(outside_to_inside[constant], + ShapeUtil::ChangeElementType( + padded_shape, constant->shape().element_type()), + nullptr, body)); + }; + while (!worklist.empty()) { + auto inst = worklist.back(); + worklist.pop_back(); + if (outside_to_inside.count(inst) > 0) { + continue; + } + if (inst->opcode() == HloOpcode::kBroadcast) { + add_broadcast(inst); + } else if (inst->opcode() == HloOpcode::kIota) { + add_iota(inst); + } else if (inst->opcode() == HloOpcode::kConstant) { + add_constant(inst); + } else if (inst->opcode() == HloOpcode::kReduce) { + // This is an output, for which we have special handling later. + } else { + std::vector operands_inside(inst->operand_count()); + for (int64 i = 0; i < operands_inside.size(); ++i) { + operands_inside[i] = outside_to_inside[inst->operand(i)]; + } + outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands( + ShapeUtil::ChangeElementType(dus->operand(1)->shape(), + inst->shape().element_type()), + operands_inside)); + } + add_users_if_available(inst); + } + std::vector new_outputs_inside(new_operands.size()); + for (int64 i = 0; i < new_outputs_inside.size(); ++i) { + new_outputs_inside[i] = outside_to_inside[new_operands[i]]; + } + // Now create the reduce outputs inside the loop. + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + auto reduce_outside = reduce_outputs[i]; + CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce); + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto last_iter_result = outside_to_inside[new_operands[index_in_operand]]; + auto operand0 = outside_to_inside[reduce_outside->operand(0)]; + auto operand1 = outside_to_inside[reduce_outside->operand(1)]; + TF_ASSIGN_OR_RETURN(auto reduce_shape, + ShapeInference::InferReduceShape( + {&operand0->shape(), &operand1->shape()}, + reduce_outside->dimensions(), + reduce_outside->to_apply()->ComputeProgramShape())); + *reduce_shape.mutable_layout() = reduce_outside->shape().layout(); + std::vector reduce_dus_offsets; + // If any collapsed dimension is windowed, we need to accumulate with last + // iteration's result. If such a dimension has padding, we also need to mask + // off invalid data.
+ bool needs_accumulate = false; + std::vector dims_to_mask; + for (int64 i = 0; i < slice_offsets.size(); ++i) { + if (absl::c_linear_search(reduce_outside->dimensions(), i)) { + if (reduce_outside->operand(0)->shape().dimensions(i) != + operand0->shape().dimensions(i)) { + needs_accumulate = true; + if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) { + dims_to_mask.push_back(i); + } + } + continue; + } + reduce_dus_offsets.push_back(slice_offsets[i]); + } + // Mask off invalid data in collapsed dimensions. + for (int64 dim : dims_to_mask) { + auto iota = body->AddInstruction(HloInstruction::CreateIota( + ShapeUtil::ChangeElementType(operand0->shape(), S32), dim)); + auto add = body->AddInstruction(HloInstruction::CreateBinary( + iota->shape(), HloOpcode::kAdd, iota, + body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), slice_offsets[dim], {})))); + auto limit = body->AddInstruction(HloInstruction::CreateBroadcast( + iota->shape(), + body->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0( + reduce_outside->operand(0)->shape().dimensions(dim)))), + {})); + auto compare = body->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit, + ComparisonDirection::kLt)); + operand0 = body->AddInstruction(HloInstruction::CreateTernary( + operand0->shape(), HloOpcode::kSelect, compare, operand0, + body->AddInstruction(HloInstruction::CreateBroadcast( + operand0->shape(), operand1, {})))); + } + auto output_inside = + body->AddInstruction(reduce_outside->CloneWithNewOperands( + reduce_shape, {operand0, operand1})); + // Accumulate with previous results if needed. + if (needs_accumulate) { + auto input_slice = + body->AddInstruction(HloInstruction::CreateDynamicSlice( + output_inside->shape(), last_iter_result, reduce_dus_offsets, + output_inside->shape().dimensions())); + output_inside = body->AddInstruction(HloInstruction::CreateBinary( + output_inside->shape(), + reduce_outside->to_apply()->root_instruction()->opcode(), + output_inside, input_slice)); + } + // Dynamic-update-slice if needed. + if (!ShapeUtil::Compatible(output_inside->shape(), + last_iter_result->shape())) { + output_inside = + body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + last_iter_result->shape(), last_iter_result, output_inside, + reduce_dus_offsets)); + } + new_outputs_inside[index_in_operand] = output_inside; + } + // Body output. + auto new_output_inside = + body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside)); + TF_RETURN_IF_ERROR( + body_root->ReplaceOperandWithDifferentShape(2, new_output_inside)); + TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus)); + // Replace uses of the reduces outside the loop. 
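The iota/add/compare/select sequence above implements a simple predicate: a position survives only if its global index (shard offset plus local index) lies inside the valid, unpadded extent; otherwise it is overwritten with the reduction's identity so padding cannot perturb the accumulated value. An element-wise analogue on plain vectors (MaskPaddedTail is an illustrative name, not the patch's API):

#include <cstdint>
#include <vector>

std::vector<float> MaskPaddedTail(const std::vector<float>& shard,
                                  int64_t shard_offset, int64_t valid_size,
                                  float identity) {
  std::vector<float> out(shard.size());
  for (int64_t i = 0; i < static_cast<int64_t>(shard.size()); ++i) {
    // Keep in-bounds data, neutralize padded positions with the identity.
    out[i] = (shard_offset + i < valid_size) ? shard[i] : identity;
  }
  return out;
}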
+ auto new_output_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_output_inside->shape(), loop, 2)); + for (int64 i = 0; i < reduce_outputs.size(); ++i) { + int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; + auto new_output = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_outputs_inside[index_in_operand]->shape(), new_output_gte, + index_in_operand)); + if (!ShapeUtil::Compatible(new_output->shape(), + reduce_outputs[i]->shape())) { + new_output = computation->AddInstruction(HloInstruction::CreateSlice( + reduce_outputs[i]->shape(), new_output, + std::vector(new_output->shape().rank(), 0), + reduce_outputs[i]->shape().dimensions(), + std::vector(new_output->shape().rank(), 1))); + } + TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output)); + TF_RETURN_IF_ERROR( + computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i])); + } + return Status::OK(); +} + +} // namespace + +Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops( + HloComputation* computation) { + for (auto& loop : windowed_dot_general_loops_) { + if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) { + // We have a dynamic-slice for the non-windowed operand in + // batch/contracting-dim windowed dot-general. So moving the + // broadcast/iota/elementwise ops into the loop could help reduce memory + // via fusion. + TF_RETURN_IF_ERROR( + SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( + loop.while_loop, 1 - loop.windowed_operand)); + } + if (!loop.windowed_in_contracting_dims) { + // We have a dynamic-update-slice for the output in + // batch/non-contracting-dim windowed dot-general. So moving reduce ops + // into the loop could help reduce memory. + TF_RETURN_IF_ERROR( + MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( + loop.while_loop)); + } + } + return Status::OK(); +} + +} // namespace spmd +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index 7e136be54e6..8006e47d90d 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" -#include "tensorflow/compiler/xla/service/dot_as_convolution_util.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -166,28 +165,6 @@ template namespace { -// Returns the replica group configuration where each replica belongs to its own -// group. 
-std::vector CreateReplicaGroups(int64 num_replicas) { - std::vector groups(num_replicas); - for (int64 i = 0; i < num_replicas; ++i) { - groups[i].add_replica_ids(i); - } - return groups; -} - -bool CanReshardWithAllToAll(const HloSharding& source, - const HloSharding& target) { - return UniqueTiledDim(source) && UniqueTiledDim(target) && - UniqueTiledDim(source) != UniqueTiledDim(target); -} - -bool CanReshardWithCollectivePermute(const HloSharding& source, - const HloSharding& target) { - return UniqueTiledDim(source) && UniqueTiledDim(target) && - UniqueTiledDim(source) == UniqueTiledDim(target) && source != target; -} - // Clears all sharding attributes from instructions in the module. This must be // called only after all SPMD transformation is complete. Status ClearShardingAttributes(HloModule* module) { @@ -208,6 +185,28 @@ Status ClearShardingAttributes(HloModule* module) { return Status::OK(); } +std::vector> GetPartitionGroupsForReplication( + const HloSharding& sharding, absl::Span replication_dims) { + int64 group_size = 1; + for (int64 i : replication_dims) { + group_size *= sharding.tile_assignment().dim(i); + } + std::vector> partition_groups( + sharding.tile_assignment().num_elements() / group_size); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 partition) { + int64 group_id = 0; + for (int64 i = 0; i < indices.size(); ++i) { + if (!absl::c_linear_search(replication_dims, i)) { + group_id *= sharding.tile_assignment().dim(i); + group_id += indices[i]; + } + } + partition_groups[group_id].push_back(partition); + }); + return partition_groups; +} + } // namespace HloInstruction* SpmdBuilder::AddInstruction( @@ -278,8 +277,80 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { return ReshardWithCollectivePermute(target); } - if (CanReshardWithAllToAll(sharding(), target)) { - return ReshardWithAllToAll(target); + if (auto src_tgt_dims = + GetReshardAllToAllSourceTargetDims(sharding(), target)) { + return ReshardWithAllToAll(target, *src_tgt_dims); + } + + // Partial replicated to tiled. + if (sharding().ReplicateOnLastTileDim() && !target.ReplicateOnLastTileDim() && + !target.IsTileMaximal()) { + // Get the temp sharding target from partial replicate to target tile dims. + // target_compatible_sharding has the same tile_assignment dimensions + // as the target and can reshard to target by collective permute. + // target_compatible_sharding could have different device assignment as + // targe. sharding() can reshard to target_compatible_sharding by + // dynamic slice. + auto target_compatible_sharding = PartialReplicateToTileCompatibleSharding( + sharding(), target.tile_assignment().dimensions()); + // Reshard to target_compatible_sharding by dynamic slice. + if (target_compatible_sharding.has_value()) { + std::vector expand_tile_dims; + std::vector tiling_dim_factors; + int64 rank = shape.rank(); + tiling_dim_factors.reserve(rank); + auto temp_target_sharding = target_compatible_sharding.value(); + for (int64 dim = 0; dim < rank; dim++) { + if (temp_target_sharding.tile_assignment().dim(dim) > + sharding().tile_assignment().dim(dim)) { + expand_tile_dims.push_back(dim); + } + tiling_dim_factors.emplace_back( + temp_target_sharding.tile_assignment().dim(dim) / + sharding().tile_assignment().dim(dim)); + } + + // Get per_group partitioner state. 
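GetPartitionGroupsForReplication above assigns each partition a group id by flattening its tile index in mixed-radix order while skipping the dimensions being replicated, so partitions that differ only along those dimensions land in the same group. The id computation in isolation, on plain vectors (GroupIdForTileIndex and its inputs are illustrative names; the per-group partitioner state set up below relies on the same grouping idea):

#include <algorithm>
#include <cstdint>
#include <vector>

int64_t GroupIdForTileIndex(const std::vector<int64_t>& tile_dims,
                            const std::vector<int64_t>& index,
                            const std::vector<int64_t>& replication_dims) {
  int64_t group_id = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(index.size()); ++i) {
    if (std::find(replication_dims.begin(), replication_dims.end(), i) !=
        replication_dims.end()) {
      continue;  // partitions differing only along these dims share a group
    }
    group_id = group_id * tile_dims[i] + index[i];
  }
  return group_id;
}

For example, with tile dims {2, 3} and replication over dim 1, the indices {0,0}, {0,1}, {0,2} all map to group 0 and the {1,*} indices to group 1.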
+ std::vector group_dims( + sharding().tile_assignment().num_dimensions() - 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + state_, sharding_grouped.device_groups, state_.b); + // 2. Get the padded_hlo, do right halo exchange if needed. + auto padded_hlo = PadFromPartialReplicateShape( + hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims, + state_.collective_ops_creator, state_.next_channel_id, + state_.partition_id, state_.b); + if (padded_hlo.has_value()) { + // 3. Slice out the tile from replicate ones. + auto shard_shape = + MakePartitionedShape(base_shape_, temp_target_sharding); + // device assignment within each group is sorted in + // HloSharding::PartialTile, thus partiton_id within each group can be + // matched with the order in tile_assignment. + Array tiling_assignment(tiling_dim_factors); + tiling_assignment.FillIota(0); + auto slice = + state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo.value(), + MakePartitionOffsets(padded_hlo.value()->shape(), + HloSharding::Tile(tiling_assignment), + per_group_partitioner_state.partition_id, + per_group_partitioner_state.b), + shard_shape.dimensions())); + slice->set_sharding(temp_target_sharding); + auto result = PartitionedHlo(slice, base_shape_, state_); + // If temp_target_sharding's device assignment is different from target, + // use collective permute to reshard. + if (CanReshardWithCollectivePermute(temp_target_sharding, target)) { + return result.ReshardWithCollectivePermute(target); + } + // If device assignment in temp_target_sharding and target are the same, + // return result directly. + return result; + } + } } // If not replicated yet, first replicate and then reshard to use one of the @@ -296,6 +367,19 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { return PartitionedHlo(copy, base_shape_, state_); } + // 'Replicated' to partial replicated. + if (target.ReplicateOnLastTileDim()) { + std::vector group_dims(target.tile_assignment().num_dimensions() - + 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto target_grouped = GroupShardingOnDims(target, group_dims); + auto partially_sharded = PerGroupSliceFromReplicated( + hlo_, state_.partition_id, target_grouped.device_groups, group_dims, + target_grouped.group_dim_sizes, state_.b); + partially_sharded->set_sharding(target); + return PartitionedHlo(partially_sharded, base_shape(), state_); + } + // 'Replicated' to 'Tiled'. auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); @@ -676,42 +760,57 @@ PartitionedHlo PartitionedHlo::Replicate() { } // 'Tiled' to 'Replicated'. 
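The Replicate()/ReplicatePartial() rewrite in the next hunk keeps the familiar fallback for when no all-gather primitive is available: dynamic-update-slice the local shard into a zero buffer of the full padded size, then all-reduce with addition. A small simulation of why that reconstructs the full array (plain vectors stand in for HLOs; equal shard sizes and contiguous offsets are assumed):

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> AllGatherViaSumOfZeros(
    const std::vector<std::vector<float>>& shards) {
  const size_t shard_size = shards.empty() ? 0 : shards[0].size();
  std::vector<float> result(shard_size * shards.size(), 0.0f);
  for (size_t p = 0; p < shards.size(); ++p) {
    std::vector<float> padded(result.size(), 0.0f);  // zero broadcast
    std::copy(shards[p].begin(), shards[p].end(),
              padded.begin() + p * shard_size);      // dynamic-update-slice
    for (size_t i = 0; i < result.size(); ++i) {
      result[i] += padded[i];                        // all-reduce (sum)
    }
  }
  return result;
}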
+ std::vector all_dims(shape.rank()); + std::iota(all_dims.begin(), all_dims.end(), 0); + HloInstruction* result = ReplicatePartial(all_dims); + result->set_sharding(HloSharding::Replicate()); + return update_cache(PartitionedHlo(result, base_shape_, state_)); +} + +HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span dims) { + CHECK(!sharding().IsTileMaximal()); + const Shape& shard_shape = hlo()->shape(); + Shape target_shape = shard_shape; + Shape padded_target_shape = shard_shape; + for (int64 i : dims) { + padded_target_shape.set_dimensions( + i, shard_shape.dimensions(i) * sharding().tile_assignment().dim(i)); + target_shape.set_dimensions(i, base_shape().dimensions(i)); + } + HloInstruction* result = nullptr; if (state_.collective_ops_creator.create_cross_partition_all_gather) { - result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding, - NewChannel()); - } - Shape padded_base_shape = shape; - for (int64 i = 0; i < padded_base_shape.rank(); ++i) { - padded_base_shape.set_dimensions( - i, shape.dimensions(i) * sharding.tile_assignment().dim(i)); + result = state_.partitioner->AllGatherShards(state_.b, hlo_, sharding(), + NewChannel(), dims, + state_.collective_ops_creator); } if (result == nullptr) { auto zero = state_.b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(shape.element_type()))); + LiteralUtil::Zero(shard_shape.element_type()))); auto zero_bcast = state_.b->AddInstruction( - HloInstruction::CreateBroadcast(padded_base_shape, zero, {})); + HloInstruction::CreateBroadcast(padded_target_shape, zero, {})); + auto offsets = MakePartitionOffsets(padded_target_shape, sharding(), + state_.partition_id, state_.b, dims); auto dus = state_.b->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - padded_base_shape, zero_bcast, hlo_, - MakePartitionOffsets(padded_base_shape, sharding, - state_.partition_id, state_.b))); + padded_target_shape, zero_bcast, hlo_, offsets)); HloComputation* reduction = - MakeBinaryAdd(shape.element_type(), state_.module); + MakeBinaryAdd(shard_shape.element_type(), state_.module); auto all_reduce = state_.collective_ops_creator.create_cross_partition_all_reduce( - state_.b, dus, reduction, NewChannel()); + state_.b, dus, reduction, + GetPartitionGroupsForReplication(sharding(), dims), NewChannel()); result = all_reduce; } - if (!ShapeUtil::Compatible(base_shape_, padded_base_shape)) { - std::vector start_indices(shape.rank(), 0); - std::vector strides(shape.rank(), 1); - result = state_.b->AddInstruction(HloInstruction::CreateSlice( - base_shape_, result, start_indices, base_shape_.dimensions(), strides)); + if (!ShapeUtil::Compatible(target_shape, padded_target_shape)) { + std::vector start_indices(target_shape.rank(), 0); + std::vector strides(target_shape.rank(), 1); + result = state_.b->AddInstruction( + HloInstruction::CreateSlice(target_shape, result, start_indices, + base_shape_.dimensions(), strides)); } - result->set_sharding(HloSharding::Replicate()); - return update_cache(PartitionedHlo(result, base_shape_, state_)); + return result; } PartitionedHlo PartitionedHlo::Broadcast() const { @@ -740,50 +839,101 @@ PartitionedHlo PartitionedHlo::Broadcast() const { MakeBinaryAdd(shape.element_type(), state_.module); auto result = state_.collective_ops_creator.create_cross_partition_all_reduce( - state_.b, operand, reduction, NewChannel()); + state_.b, operand, reduction, {}, NewChannel()); result->set_sharding(HloSharding::Replicate()); return PartitionedHlo(result, base_shape_, state_); } PartitionedHlo 
PartitionedHlo::ReshardWithAllToAll( - const HloSharding& target) const { - int64 partition_count = sharding().tile_assignment().num_elements(); - absl::optional input_partition_dim = UniqueTiledDim(sharding()); - absl::optional output_partition_dim = UniqueTiledDim(target); - CHECK(input_partition_dim.has_value()); - CHECK(output_partition_dim.has_value()); - - // If the device order is different in the target, fix the order with - // ReshardWithCollectivePermute. - auto input_tile_fixed_device_order = target.tile_assignment(); - input_tile_fixed_device_order.Reshape( - sharding().tile_assignment().dimensions()); - auto input_sharding_fixed_device_order = - HloSharding::Tile(input_tile_fixed_device_order); - if (input_sharding_fixed_device_order != sharding()) { - auto fixed_order = - ReshardWithCollectivePermute(input_sharding_fixed_device_order); - return fixed_order.ReshardWithAllToAll(target); + const HloSharding& target, + absl::Span> source_target_dims) const { + if (source_target_dims.empty()) { + if (target == sharding()) { + return *this; + } + // If the device order is different in the target, fix the order with + // ReshardWithCollectivePermute. + return ReshardWithCollectivePermute(target); } - auto padded_hlo = - PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b); + // Swap one pair of dimensions. + int64 source_dim = source_target_dims[0].first; + int64 target_dim = source_target_dims[0].second; + const int64 group_size = sharding().tile_assignment().dim(source_dim) / + sharding().tile_assignment().dim(target_dim); - // The order of ids in the group must follow the target sharding. - std::vector groups(1); - for (int64 device : target.tile_assignment()) { - groups[0].add_replica_ids(device); + auto temp_target_tile = sharding().tile_assignment(); + { + std::vector reshape_tile_dims(temp_target_tile.num_dimensions() + 2); + int64 i = 0; + int64 added_source_dim = -1; + int64 added_target_dim = -1; + for (int64 j = 0; j < temp_target_tile.num_dimensions(); ++j) { + if (source_dim == j) { + reshape_tile_dims[i] = temp_target_tile.dim(j) / group_size; + reshape_tile_dims[++i] = group_size; + added_source_dim = i; + } else if (target_dim == j) { + reshape_tile_dims[i] = temp_target_tile.dim(j); + reshape_tile_dims[++i] = 1; + added_target_dim = i; + } else { + reshape_tile_dims[i] = temp_target_tile.dim(j); + } + ++i; + } + temp_target_tile.Reshape(reshape_tile_dims); + std::vector xpose_dims(temp_target_tile.num_dimensions()); + std::iota(xpose_dims.begin(), xpose_dims.end(), 0); + xpose_dims[added_source_dim] = added_target_dim; + xpose_dims[added_target_dim] = added_source_dim; + temp_target_tile = hlo_sharding_util::TransposeSharding( + HloSharding::Tile(temp_target_tile), xpose_dims) + .tile_assignment(); + auto temp_target_tile_dims = sharding().tile_assignment().dimensions(); + temp_target_tile_dims[source_dim] = + sharding().tile_assignment().dim(target_dim); + temp_target_tile_dims[target_dim] = + sharding().tile_assignment().dim(source_dim); + temp_target_tile.Reshape(temp_target_tile_dims); } + auto temp_target = target.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(temp_target_tile) + : HloSharding::Tile(temp_target_tile); + auto padded_shape = hlo_->shape(); + padded_shape.set_dimensions( + target_dim, + RoundUpToNearest(padded_shape.dimensions(target_dim), + temp_target.tile_assignment().dim(target_dim))); + auto padded_hlo = PadToShape(hlo_, padded_shape, state_.b); + + // The order of ids in the group must follow the temp_target sharding. 
+ std::vector> groups( + temp_target.tile_assignment().num_elements() / group_size); + temp_target.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + int64 group_id = 0; + for (int64 dim = 0; dim < indices.size(); ++dim) { + if (dim == target_dim) { + group_id *= temp_target.tile_assignment().dim(dim) / group_size; + group_id += indices[dim] / group_size; + } else { + group_id *= temp_target.tile_assignment().dim(dim); + group_id += indices[dim]; + } + } + groups[group_id].push_back(device); + }); HloInstruction* result = nullptr; - // Split along the split dimension (output_partition_dim) of the all-to-all + // Split along the split dimension (target_dim) of the all-to-all // output. std::vector dimensions; for (int64 i = 0; i < base_shape_.rank(); ++i) { - if (i == *output_partition_dim) { - dimensions.push_back(partition_count); - dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count); + if (i == target_dim) { + dimensions.push_back(group_size); + dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size); } else { dimensions.push_back(padded_hlo->shape().dimensions(i)); } @@ -794,21 +944,19 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( // After the reshape, it is guaranteed to have at least 3 dimensions. auto all_to_all = state_.collective_ops_creator.create_cross_partition_all_to_all( - state_.b, {reshape}, groups, (*state_.next_channel_id)++, - output_partition_dim); + state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim); // Reorder the split dimension of the reshape to be located in front of the // input partition dimension, so the two dimensions can be combined. - int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim) - ? *input_partition_dim + 1 - : *input_partition_dim; + int64 new_source_dim = + (target_dim < source_dim) ? source_dim + 1 : source_dim; std::vector permutation; for (int64 i = 0; i < all_to_all->shape().rank(); ++i) { - if (i == *output_partition_dim) { + if (i == target_dim) { continue; } - if (i == new_input_partition_dim) { - permutation.push_back(*output_partition_dim); + if (i == new_source_dim) { + permutation.push_back(target_dim); } permutation.push_back(i); } @@ -819,32 +967,33 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll( // Combine the split dimension and the input partition dimension. 
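The combine step referenced just above is the final reshape in the next hunk; together with the earlier split and transpose it moves sharding from the source dimension to the target dimension. The dimension bookkeeping restated on plain shape dimensions (AllToAllDimPlan and PlanAllToAllDims are illustrative names, not the patch's API):

#include <cstdint>
#include <vector>

struct AllToAllDimPlan {
  std::vector<int64_t> split_dims;   // shape fed into the all-to-all
  std::vector<int64_t> permutation;  // moves the group dim next to source_dim
  std::vector<int64_t> merged_dims;  // shape after the final reshape
};

AllToAllDimPlan PlanAllToAllDims(const std::vector<int64_t>& padded_dims,
                                 int64_t source_dim, int64_t target_dim,
                                 int64_t group_size) {
  AllToAllDimPlan plan;
  const int64_t rank = static_cast<int64_t>(padded_dims.size());
  // Split target_dim into [group_size, size / group_size].
  for (int64_t i = 0; i < rank; ++i) {
    if (i == target_dim) {
      plan.split_dims.push_back(group_size);
      plan.split_dims.push_back(padded_dims[i] / group_size);
    } else {
      plan.split_dims.push_back(padded_dims[i]);
    }
  }
  // Transpose the split (group) dim to sit right before the source dim.
  const int64_t new_source_dim =
      (target_dim < source_dim) ? source_dim + 1 : source_dim;
  for (int64_t i = 0; i < rank + 1; ++i) {
    if (i == target_dim) continue;  // the group dim itself
    if (i == new_source_dim) plan.permutation.push_back(target_dim);
    plan.permutation.push_back(i);
  }
  // Merge the group dim into the source dim; target_dim shrinks accordingly.
  for (int64_t i = 0; i < rank; ++i) {
    if (i == source_dim) {
      plan.merged_dims.push_back(padded_dims[i] * group_size);
    } else if (i == target_dim) {
      plan.merged_dims.push_back(padded_dims[i] / group_size);
    } else {
      plan.merged_dims.push_back(padded_dims[i]);
    }
  }
  return plan;
}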
auto new_shape = ShapeInference::InferAllToAllShape( - padded_hlo->shape(), *output_partition_dim, - *input_partition_dim, partition_count) + padded_hlo->shape(), target_dim, source_dim, group_size) .ValueOrDie(); result = state_.b->AddInstruction( HloInstruction::CreateReshape(new_shape, transpose)); - const Shape result_shape = MakePartitionedShape(base_shape_, target); + const Shape result_shape = MakePartitionedShape(base_shape_, temp_target); if (result_shape != result->shape()) { result = state_.b->AddInstruction(HloInstruction::CreateSlice( result_shape, result, std::vector(result_shape.rank(), 0), result_shape.dimensions(), std::vector(result_shape.rank(), 1))); } - result->set_sharding(target); - return PartitionedHlo(result, base_shape_, state_); + result->set_sharding(temp_target); + auto remaining_source_target_dims = source_target_dims; + remaining_source_target_dims.remove_prefix(1); + return PartitionedHlo(result, base_shape_, state_) + .ReshardWithAllToAll(target, remaining_source_target_dims); } PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( const HloSharding& target) const { - CHECK(CanReshardWithCollectivePermute(sharding(), target)); + CHECK(CanReshardWithCollectivePermute(sharding(), target)) + << sharding().ToString() << " to " << target.ToString(); std::vector> src_dst_pairs; sharding().tile_assignment().Each( [&](absl::Span indices, int64 src_device) { int64 dst_device = target.tile_assignment()(indices); - if (dst_device != src_device) { - src_dst_pairs.emplace_back(src_device, dst_device); - } + src_dst_pairs.emplace_back(src_device, dst_device); }); auto cp = state_.collective_ops_creator.create_cross_partition_collective_permute( @@ -990,7 +1139,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { offset += operand->shape().dimensions(dimension); } auto all_reduce = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), + &b_, temp_output, MakeBinaryAdd(hlo->shape().element_type(), module_), {}, NewChannel()); SetPartitionedHlo(hlo, [&] { auto start_indices = @@ -1005,47 +1154,7 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) { return Status::OK(); } -// If partitioning in the operand only happens in dimensions in passthrough -// dimensions (offset dimensions in the gather output (or scatter update) that -// have the same size as the operand), returns the corresponding output (or -// update) sharding by passing through the input sharding. 
-absl::optional PassthroughOperandToGatherOutputOrScatterUpdate( - const PartitionedHlo& operand, const Shape& update_or_gather_shape, - absl::Span collapsed_or_inserted_dims, - absl::Span index_map, - absl::Span offset_or_window_dims, - absl::Span slice_size) { - if (operand.sharding().IsTileMaximal()) { - return operand.sharding(); - } - std::vector passthrough_tile(update_or_gather_shape.rank(), 1); - int64 collapsed = 0; - for (int64 i = 0; i < operand.base_shape().rank(); ++i) { - int64 dim_partitions = operand.sharding().tile_assignment().dim(i); - if (absl::c_linear_search(collapsed_or_inserted_dims, i) || - absl::c_linear_search(index_map, i)) { - if (dim_partitions > 1) { - return absl::nullopt; - } - collapsed++; - continue; - } - if (slice_size[i] != operand.base_shape().dimensions(i) && - dim_partitions > 1) { - return absl::nullopt; - } - int64 offset_dim = offset_or_window_dims[i - collapsed]; - if (i - collapsed > 0 && - offset_dim < offset_or_window_dims[i - collapsed - 1]) { - // Output offsets are transposed, we do not support this case. - return absl::nullopt; - } - passthrough_tile[offset_dim] = dim_partitions; - } - Array tile_assignment = operand.sharding().tile_assignment(); - tile_assignment.Reshape(passthrough_tile); - return HloSharding::Tile(tile_assignment); -} +namespace { // Returns whether partitioning in the operand only happens in dimensions with // gather/scatter slice size 1. @@ -1140,6 +1249,8 @@ IndexBoundsForGatherScatterOperandPartitionedOnTrivialSliceDims( return {broadcast_min, broadcast_max}; } +} // namespace + Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { auto scatter = Cast(hlo); auto dnums = scatter->scatter_dimension_numbers(); @@ -1155,17 +1266,87 @@ Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { slice_size[i] = updates.base_shape().dimensions( dnums.update_window_dims(num_update_window_dims++)); } - std::vector inserted_window_dims(dnums.inserted_window_dims().begin(), - dnums.inserted_window_dims().end()); std::vector scatter_dims_to_operand_dims( dnums.scatter_dims_to_operand_dims().begin(), dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); - if (!operand.sharding().IsTileMaximal()) { - auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( - operand, updates.base_shape(), inserted_window_dims, - scatter_dims_to_operand_dims, update_window_dims, slice_size); + std::vector update_scatter_dims; + for (int64 i = 0; i < updates.base_shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.update_window_dims(), i)) { + update_scatter_dims.push_back(i); + } + } + if (operand.sharding().IsTileMaximal()) { + if (!indices.sharding().IsTileMaximal() && + (dnums.index_vector_dim() == indices.base_shape().rank() || + indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == + 1)) { + auto reduction_opcode = ParseReductionComputation(scatter->to_apply()); + if (!reduction_opcode.has_value()) { + return DefaultAction(hlo); + } + HloInstruction* identity; + switch (*reduction_opcode) { + case HloOpcode::kAdd: + case HloOpcode::kOr: + identity = CreateZero(operand.hlo()->shape(), &b_); + break; + case HloOpcode::kMultiply: + case HloOpcode::kAnd: + identity = CreateOne(operand.hlo()->shape(), &b_); + break; + case HloOpcode::kMinimum: + identity = CreateConstant( + operand.hlo()->shape(), + LiteralUtil::MaxValue(hlo->shape().element_type()), &b_); + break; + case 
HloOpcode::kMaximum: + identity = CreateConstant( + operand.hlo()->shape(), + LiteralUtil::MinValue(hlo->shape().element_type()), &b_); + break; + default: + return DefaultAction(hlo); + } + std::vector update_dim_to_index_dim(updates.base_shape().rank(), + -1); + std::vector index_dim_to_update_dim(indices.base_shape().rank(), + -1); + for (int64 i = 0; i < update_scatter_dims.size(); ++i) { + int64 indices_scatter_dim = i < dnums.index_vector_dim() ? i : i + 1; + update_dim_to_index_dim[update_scatter_dims[i]] = indices_scatter_dim; + index_dim_to_update_dim[indices_scatter_dim] = update_scatter_dims[i]; + } + auto new_updates_sharding = TransposeShardingWithCollapsedDims( + indices.sharding(), index_dim_to_update_dim, update_dim_to_index_dim); + CHECK(new_updates_sharding.has_value()); + updates = updates.Reshard(*new_updates_sharding); + // To avoid accumulating the initial operand multiple times during + // all-reduce, we use identity operands for all non-zero partitions. + auto not_partition_zero = b_.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeScalarShape(PRED), partition_id_)); + not_partition_zero = b_.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType(identity->shape(), PRED), + not_partition_zero, {})); + auto select_operand = + b_.AddInstruction(HloInstruction::HloInstruction::CreateTernary( + identity->shape(), HloOpcode::kSelect, not_partition_zero, + identity, operand.Replicate().hlo())); + auto pscatter = b_.AddInstruction(scatter->CloneWithNewOperands( + scatter->shape(), {select_operand, indices.hlo(), updates.hlo()})); + auto all_reduce = + collective_ops_creator_.create_cross_partition_all_reduce( + &b_, pscatter, scatter->to_apply(), {}, NewChannel()); + all_reduce->set_sharding(HloSharding::Replicate()); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(all_reduce, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } else { + auto maybe_passthrough = hlo_sharding_util::ScatterUpdateShardingFromOutput( + operand.sharding(), *hlo); // Handle pass through cases if we can use compatible sharding for update. if (maybe_passthrough.has_value()) { indices = indices.Reshard(HloSharding::Replicate()); @@ -1865,67 +2046,22 @@ Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) { auto& operand = GetPartitionedHlo(hlo->operand(0)); // Tiled output. - std::vector wanted_input_tile_size(operand.base_shape().rank()); - std::vector sharded_new_dims; - for (int64 i = 0; i < operand.base_shape().rank(); ++i) { - wanted_input_tile_size[i] = - hlo->sharding().tile_assignment().dim(hlo->dimensions(i)); - } + std::vector new_dims; for (int64 i = 0; i < hlo->shape().rank(); ++i) { - if (!absl::c_linear_search(hlo->dimensions(), i) && - hlo->sharding().tile_assignment().dim(i) > 1) { - sharded_new_dims.push_back(i); + if (!absl::c_linear_search(hlo->dimensions(), i)) { + new_dims.push_back(i); } } - if (sharded_new_dims.empty()) { - // The new dimensions are replicated, so that we can do the adjustment on - // the input. 
- Array wanted_input_tile_assignment(wanted_input_tile_size); - wanted_input_tile_assignment.Each( - [&](absl::Span indices, int64* val) { - std::vector indices_in_broadcast(hlo->shape().rank(), 0); - for (int64 i = 0; i < operand.base_shape().rank(); ++i) { - indices_in_broadcast[hlo->dimensions(i)] = indices[i]; - } - *val = hlo->sharding().tile_assignment()(indices_in_broadcast); - }); - SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction(hlo->CloneWithNewOperands( - MakePartitionedShape(hlo->shape(), hlo->sharding()), - {operand.Reshard(HloSharding::Tile(wanted_input_tile_assignment)) - .hlo()})); - }); - } else { - auto input = operand.Reshard(HloSharding::Replicate()).hlo(); - // We pad and shard the input first, then broadcast to the final shard - // shape. - auto output_offsets = - MakePartitionOffsets(hlo->shape(), hlo->sharding(), partition_id_, &b_); - std::vector input_offsets(operand.base_shape().rank()); - auto output_shard_shape = - MakePartitionedShape(hlo->shape(), hlo->sharding()); - auto input_shard_shape = input->shape(); - auto padded_input_shape = input->shape(); - for (int64 i = 0; i < input_offsets.size(); ++i) { - input_offsets[i] = output_offsets[hlo->dimensions(i)]; - input_shard_shape.set_dimensions( - i, output_shard_shape.dimensions(hlo->dimensions(i))); - padded_input_shape.set_dimensions( - i, hlo->sharding().tile_assignment().dim(hlo->dimensions(i)) * - input_shard_shape.dimensions(i)); - } - auto padded_input = PadToShape(input, padded_input_shape, &b_); - auto input_shard = - ShapeUtil::Compatible(input_shard_shape, padded_input->shape()) - ? padded_input - : b_.AddInstruction(HloInstruction::CreateDynamicSlice( - input_shard_shape, padded_input, input_offsets, - input_shard_shape.dimensions())); - SetPartitionedHlo(hlo, [&] { - return b_.AddInstruction( - hlo->CloneWithNewOperands(output_shard_shape, {input_shard})); - }); - } + auto desired_input_sharding = hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(hlo->sharding(), + new_dims), + new_dims); + auto input = operand.Reshard(desired_input_sharding).hlo(); + auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); + SetPartitionedHlo(hlo, [&] { + return b_.AddInstruction( + hlo->CloneWithNewOperands(output_shard_shape, {input})); + }); return Status::OK(); } @@ -2019,16 +2155,50 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { const auto& dnums = gather->gather_dimension_numbers(); auto operand = GetPartitionedHlo(gather->operand(0)); auto indices = GetPartitionedHlo(gather->operand(1)); - std::vector collapsed_slice_dims(dnums.collapsed_slice_dims().begin(), - dnums.collapsed_slice_dims().end()); std::vector start_index_map(dnums.start_index_map().begin(), dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); - if (!operand.sharding().IsTileMaximal()) { - auto maybe_passthrough = PassthroughOperandToGatherOutputOrScatterUpdate( - operand, gather->shape(), collapsed_slice_dims, start_index_map, - offset_dims, gather->gather_slice_sizes()); + std::vector batch_dims; + for (int64 i = 0; i < gather->shape().rank(); ++i) { + if (!absl::c_linear_search(dnums.offset_dims(), i)) { + batch_dims.push_back(i); + } + } + if (operand.sharding().IsTileMaximal()) { + if (!indices.sharding().IsTileMaximal() && + (dnums.index_vector_dim() == indices.base_shape().rank() || + indices.sharding().tile_assignment().dim(dnums.index_vector_dim()) == + 1)) { + auto 
replicated_operand = operand.Replicate(); + TF_ASSIGN_OR_RETURN( + Shape partitioned_output_shape, + ShapeInference::InferGatherShape(replicated_operand.hlo()->shape(), + indices.hlo()->shape(), dnums, + gather->gather_slice_sizes())); + auto pgather = b_.AddInstruction(gather->CloneWithNewOperands( + partitioned_output_shape, {replicated_operand.hlo(), indices.hlo()})); + std::vector output_dim_to_index_dim(pgather->shape().rank(), -1); + std::vector index_dim_to_output_dim(indices.base_shape().rank(), + -1); + for (int64 i = 0; i < batch_dims.size(); ++i) { + int64 indices_batch_dim = i < dnums.index_vector_dim() ? i : i + 1; + output_dim_to_index_dim[batch_dims[i]] = indices_batch_dim; + index_dim_to_output_dim[indices_batch_dim] = batch_dims[i]; + } + auto pgather_sharding = TransposeShardingWithCollapsedDims( + indices.sharding(), index_dim_to_output_dim, output_dim_to_index_dim); + CHECK(pgather_sharding.has_value()); + pgather->set_sharding(*pgather_sharding); + SetPartitionedHlo(hlo, [&]() { + return PartitionedHlo(pgather, hlo->shape(), MakePartitioningState()) + .Reshard(hlo->sharding()) + .hlo(); + }); + return Status::OK(); + } + } else { + auto maybe_passthrough = + hlo_sharding_util::GatherOutputShardingFromDataOperand( + operand.sharding(), *hlo); if (maybe_passthrough.has_value()) { indices = indices.Reshard(HloSharding::Replicate()); auto pshape = MakePartitionedShape(gather->shape(), *maybe_passthrough); @@ -2116,7 +2286,7 @@ Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) { // Combine from different partitions. auto ar = collective_ops_creator_.create_cross_partition_all_reduce( &b_, filtered, - MakeBinaryAdd(filtered->shape().element_type(), module_), + MakeBinaryAdd(filtered->shape().element_type(), module_), {}, NewChannel()); ar->set_sharding(HloSharding::Replicate()); SetPartitionedHlo(hlo, [&]() { @@ -2227,31 +2397,47 @@ Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) { /*parameter_number=*/0, token->shape(), "infeed_token_param")); auto infeed = branch_b.AddInstruction(HloInstruction::CreateInfeed( per_branch_partitioned_shapes[i], param, hlo->infeed_config())); - branches[i] = module_->AddEmbeddedComputation(branch_b.Build(infeed)); if (!ShapeUtil::Compatible(per_branch_partitioned_shapes[i], shard_shape)) { - TF_ASSIGN_OR_RETURN( - auto padded, - branches[i]->DeepCopyInstructionWithCustomCopier( - infeed, [&](HloInstruction* leaf, const ShapeIndex& leaf_index, - HloComputation* comp) { - // Index {1} corresponds to the token. - if (leaf_index.empty() || leaf_index[0] != 0) { - return leaf; - } - ShapeIndexView subindex(leaf_index, 1); - if (ShapeUtil::Compatible( - ShapeUtil::GetSubshape(per_branch_partitioned_shapes[i], - subindex), - ShapeUtil::GetSubshape(shard_shape, subindex))) { - return leaf; - } - return PadToShape(leaf, - ShapeUtil::GetSubshape(shard_shape, subindex), - nullptr, comp); - })); - branches[i]->set_root_instruction(padded, - /*accept_different_shape=*/true); + std::function + pad_infeed = [&](const ShapeIndex& index, + HloInstruction* infeed_element) -> HloInstruction* { + if (index == ShapeIndex({1})) { + // Token. 
+ return infeed_element; + } + const Shape& element_shape = + ShapeUtil::GetSubshape(infeed->shape(), index); + if (element_shape.IsTuple() && element_shape.tuple_shapes_size() > 0) { + std::vector padded_elements( + element_shape.tuple_shapes_size()); + for (int64 i = 0; i < padded_elements.size(); ++i) { + auto sub_index = index; + sub_index.push_back(i); + padded_elements[i] = pad_infeed( + sub_index, + branch_b.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(element_shape, {i}), infeed_element, + i))); + } + return branch_b.AddInstruction( + HloInstruction::CreateTuple(padded_elements)); + } + const Shape& pad_shape = + ShapeUtil::GetSubshape(shard_shape, ShapeIndexView(index, 1)); + if (ShapeUtil::Compatible(element_shape, pad_shape)) { + return infeed_element; + } + if (element_shape.IsArray()) { + CHECK(pad_shape.IsArray()); + return PadToShape(infeed_element, pad_shape, &branch_b); + } + CHECK(element_shape.IsTuple()); + CHECK(element_shape.tuple_shapes().empty()); + return CreateZero(pad_shape, &branch_b); + }; + pad_infeed({}, infeed); } + branches[i] = module_->AddEmbeddedComputation(branch_b.Build()); } SetPartitionedHlo(hlo, [&]() { return b_.AddInstruction(HloInstruction::CreateConditional( @@ -2374,17 +2560,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { if (reduce_sharded_dimension && input_count > 1) { return DefaultAction(hlo); } - - // Currently we only support reducing all or none of the sharded - // dimensions. - if (reduce_sharded_dimension) { - for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { - if (inputs[0].sharding().tile_assignment().dim(i) > 1 && - absl::c_count(hlo->dimensions(), i) == 0) { - return DefaultAction(hlo); - } - } - } } std::vector new_operand_shapes(input_count * 2); @@ -2397,7 +2572,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { auto reduce_shape, ShapeInference::InferReduceShape(new_operand_shapes, hlo->dimensions(), hlo->to_apply()->ComputeProgramShape())); - *reduce_shape.mutable_layout() = hlo->shape().layout(); std::vector input_hlos(input_count); for (int64 i = 0; i < input_count; ++i) { @@ -2408,36 +2582,30 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { local_reduce->set_metadata(hlo->metadata()); SetPartitionedHlo(hlo, [&]() { - HloInstruction* reduce; + HloInstruction* reduce = local_reduce; if (reduce_sharded_dimension) { CHECK(local_reduce->shape().IsArray()); - reduce = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, local_reduce, hlo->to_apply(), NewChannel()); - reduce->set_sharding(HloSharding::Replicate()); - } else { - reduce = local_reduce; - if (inputs[0].sharding().IsTileMaximal()) { - reduce->set_sharding(inputs[0].sharding()); - } else { - // Remove tile assignment dimensions that are reduced. 
- std::vector tile_dimensions; - for (int64 i = 0; i < input_hlos[0]->shape().rank(); ++i) { - if (absl::c_count(hlo->dimensions(), i) == 0) { - tile_dimensions.push_back( - inputs[0].sharding().tile_assignment().dim(i)); - } + std::vector preserved_dims; + for (int64 i = 0; i < inputs[0].base_shape().rank(); ++i) { + if (!absl::c_linear_search(hlo->dimensions(), i)) { + preserved_dims.push_back(i); } - Array new_tile = inputs[0].sharding().tile_assignment(); - new_tile.Reshape(tile_dimensions); - auto sharding = HloSharding::Tile(new_tile); - if (input_count > 1) { - std::vector tuple(input_count, sharding); - sharding = HloSharding::Tuple(hlo->shape(), tuple); - } - reduce->set_sharding(sharding); } + if (inputs[0].sharding().ReplicateOnLastTileDim()) { + preserved_dims.push_back(inputs[0].base_shape().rank()); + } + auto grouped = GroupShardingOnDims(inputs[0].sharding(), preserved_dims); + auto grouped_state = CreatePerGroupPartitioningState( + inputs[0].state(), grouped.device_groups, &b_); + reduce = grouped_state.collective_ops_creator + .create_cross_partition_all_reduce( + &b_, local_reduce, hlo->to_apply(), {}, NewChannel()); } - + auto sharding = hlo_sharding_util::RemoveShapeDimensions( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + inputs[0].sharding(), hlo->dimensions()), + hlo->dimensions()); + reduce->set_sharding(sharding); return PartitionedHlo(reduce, hlo->shape(), MakePartitioningState()) .Reshard(hlo->sharding()) .hlo(); @@ -2846,1774 +3014,6 @@ Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { return Status::OK(); } -Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs( - HloInstruction* hlo) { - TF_RET_CHECK(hlo->opcode() == HloOpcode::kConvolution); - - auto lhs = GetPartitionedHlo(hlo->operand(0)); - auto rhs = GetPartitionedHlo(hlo->operand(1)); - TF_RET_CHECK(!lhs.sharding().IsTileMaximal() && - !rhs.sharding().IsTileMaximal()); - - const auto& dnums = hlo->convolution_dimension_numbers(); - - // Check if the operand shardings are aligned. Also we currently don't - // support partitioning non-spatial dimensions. - std::vector rhs_to_lhs_indices(hlo->shape().rank()); - rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = - dnums.input_batch_dimension(); - rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = - dnums.input_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = - dnums.input_spatial_dimensions(i); - } - std::vector lhs_to_rhs_indices(hlo->shape().rank()); - for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { - lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; - } - - Window window = hlo->window(); - std::vector reversed_rhs_dims; - for (int64 i = 0; i < window.dimensions_size(); ++i) { - if (window.dimensions(i).window_reversal()) { - reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i)); - } - } - if (!reversed_rhs_dims.empty()) { - // Make the reversed dims left-padded to prepare for window reversal. - auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims); - if (left_padded_rhs == nullptr) { - return DefaultAction(hlo); - } - left_padded_rhs->set_sharding(rhs.sharding()); - rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state()); - } - // Consider window reversal when resharding RHS or LHS. Note: this will not - // reverse the data in the shard. We use window reversal to do that. 
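The removed hunk that follows derives per-shard halo sizes from the projection of each RHS shard onto the LHS. For readability, the expressions from that comment, where LHS and RHS denote per-shard sizes, D the window dilation, WC the window count, stride the window stride, and i the shard index:

\[
\begin{aligned}
\mathrm{offset}(i)      &= \mathrm{RHS}\cdot D\cdot i - \mathrm{low\_padding}\\
\mathrm{limit}(i)       &= \bigl((\mathrm{RHS}-1)\cdot D + 1\bigr)\cdot(i+1) + (\mathrm{WC}-1)\cdot \mathrm{stride} - \mathrm{low\_padding}\\
\mathrm{left\_halo}(i)  &= i\cdot \mathrm{LHS} - \mathrm{offset}(i)\\
\mathrm{right\_halo}(i) &= \mathrm{limit}(i) - (i+1)\cdot \mathrm{LHS}
\end{aligned}
\]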
- auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding( - hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices), - reversed_rhs_dims); - auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding( - hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims), - lhs_to_rhs_indices); - - auto unsupported_sharding = [&](const HloSharding& lhs_sharding, - const HloSharding& rhs_sharding) { - return lhs_sharding.tile_assignment().dim(dnums.input_batch_dimension()) != - 1 || - rhs_sharding.tile_assignment().dim( - dnums.kernel_output_feature_dimension()) != 1; - }; - - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); - if (ShapeSizeInBytes(lhs.base_shape()) < ShapeSizeInBytes(rhs.base_shape())) { - if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { - return DefaultAction(hlo); - } - lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); - rhs = rhs.PadWithValue(zero, reversed_rhs_dims); - } else { - if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { - return DefaultAction(hlo); - } - lhs = lhs.PadWithValue(zero); - rhs = - rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims); - } - - // Reshard LHS by exchanging halo such that each shard computes the partial - // sum of the full shape result, and add AllReduce. - // - // The size of halo on each dimension can be calculated from the projection - // onto the LHS that each RHS shard i needs to read. RHS and LHS below refers - // to the shard size of RHS and LHS, WC is the number of windows, and D is the - // window dilation. - // - // * offset(i): RHS * D * i - low_padding - // * limit(i): {(RHS - 1) * D + 1} * (i + 1) + (WC - 1) * stride - low_padding - // - // Since shard i has LHS of range [i * LHS, (i + 1) * LHS) - // * left-halo: i * LHS - offset(i) - // = (LHS - RHS) * i + low_padding - // * right-halo: limit(i) - (i + 1) * LHS - // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding - std::vector shard_counts(dnums.input_spatial_dimensions_size()); - std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); - std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dimension = dnums.input_spatial_dimensions(i); - int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); - int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension); - auto wd = window.dimensions(i); - if (wd.base_dilation() != 1) { - return DefaultAction(hlo); - } - - int64 lhs_shard_size = - CeilOfRatio(lhs.base_shape().dimensions(lhs_dimension), shard_count); - int64 rhs_shard_size = - CeilOfRatio(rhs.base_shape().dimensions(rhs_dimension), shard_count); - shard_counts[i] = shard_count; - lhs_shard_sizes[i] = lhs_shard_size; - rhs_shard_sizes[i] = rhs_shard_size; - } - - std::vector left_halo_size_functions(hlo->shape().rank()); - std::vector right_halo_size_functions(hlo->shape().rank()); - Window new_window = window; - - auto partition_ordinals = - MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); - HloInstruction* lhs_with_halo = lhs.hlo(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dimension = dnums.input_spatial_dimensions(i); - int64 lhs_shard_size = lhs_shard_sizes[i]; - int64 rhs_shard_size = rhs_shard_sizes[i]; - - if (shard_counts[i] == 1) { - continue; - } - - // Calculate the left and right halo sizes as described 
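The left and right halo sizes in the comment above are affine in the shard index i. A small standalone evaluation of those formulas under hypothetical sizes (base dilation 1, as this path requires; window dilation also set to 1 here), which can help sanity-check a concrete configuration:

#include <cstdint>
#include <iostream>

// Evaluates the two halo-size functions described in the comment above for
// each shard index, using made-up sizes.
int main() {
  const int64_t shards = 2, stride = 1, dilation = 1;
  const int64_t base = 16, kernel = 3;  // full LHS / RHS extents on this dim
  const int64_t padding_low = 1, padding_high = 1;
  const int64_t lhs_shard = (base + shards - 1) / shards;    // CeilOfRatio
  const int64_t rhs_shard = (kernel + shards - 1) / shards;  // CeilOfRatio
  const int64_t rhs_dilated = (rhs_shard - 1) * dilation + 1;
  const int64_t window_count =
      1 + (padding_low + padding_high + base - (1 + (kernel - 1) * dilation)) /
              stride;

  for (int64_t i = 0; i < shards; ++i) {
    // left-halo(i)  = (LHS - RHS * D) * i + low_padding
    const int64_t left = (lhs_shard - rhs_shard * dilation) * i + padding_low;
    // right-halo(i) = (dilated RHS - LHS) * (i + 1) + (WC - 1) * stride - low_padding
    const int64_t right = (rhs_dilated - lhs_shard) * (i + 1) +
                          (window_count - 1) * stride - padding_low;
    std::cout << "shard " << i << ": left halo " << left << ", right halo "
              << right << "\n";
  }
  return 0;
}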
in the comments - // above. - auto wd = window.dimensions(i); - int64 padding_low = wd.padding_low(); - int64 padding_high = wd.padding_high(); - int64 base = lhs.base_shape().dimensions(lhs_dimension); - int64 window_count = 1 + (padding_low + padding_high + base - - (1 + (wd.size() - 1) * wd.window_dilation())) / - wd.stride(); - int64 rhs_shard_size_dilated = - (rhs_shard_size - 1) * wd.window_dilation() + 1; - - left_halo_size_functions[lhs_dimension] = - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - lhs_shard_size - rhs_shard_size * wd.window_dilation(), padding_low, - 1)); - right_halo_size_functions[lhs_dimension] = - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - rhs_shard_size_dilated - lhs_shard_size, - rhs_shard_size_dilated - lhs_shard_size + - wd.stride() * (window_count - 1) - padding_low, - 1)); - - // Exchange halo and concatenate. - int64 dim = dnums.input_spatial_dimensions(i); - int64 explicit_left_padding_on_full_shape = padding_low; - int64 shard_size_with_halo = - wd.stride() * (window_count - 1) + rhs_shard_size_dilated; - - new_window.mutable_dimensions(i)->set_padding_low(0); - new_window.mutable_dimensions(i)->set_padding_high(0); - new_window.mutable_dimensions(i)->set_size(rhs_shard_size); - - // offset_on_padded_shape and padded_full_shape_size are needed only if - // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). - // Since the default value for both the collective-permute is zero and - // also we call PadWithValue() on both operands at the beginning, we - // don't need to mask here. - // - // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls - // if it's always safe. - auto offset_on_padded_shape = - OffsetCalculation(MultiplyAddDivideOffsetCalculation()); - int64 padded_full_shape_size = 0; - auto concat = ExchangeHaloAndGetValidData( - lhs_with_halo, lhs.base_shape(), left_halo_size_functions[dim], - right_halo_size_functions[dim], explicit_left_padding_on_full_shape, - padded_full_shape_size, shard_size_with_halo, dim, lhs.sharding(), - offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), zero, - partition_ordinals[dim], collective_ops_creator_, next_channel_id_, &b_, - /*mask_invalid_region=*/false); - if (!concat) { - return DefaultAction(hlo); - } - lhs_with_halo = *concat; - } - - SetPartitionedHlo(hlo, [&]() { - auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( - hlo->shape(), lhs_with_halo, rhs.hlo(), hlo->feature_group_count(), - hlo->batch_group_count(), new_window, - hlo->convolution_dimension_numbers(), hlo->precision_config())); - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); -} - -Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { - auto dot_dnums = dot_as_convolution_util::ParseDotGeneralFromConvolution(hlo); - if (dot_dnums) { - // Use HandleDotHelper() for convs that are actually einsums. 
- spmd::DotGeneralDimsMapping mapping; - for (const auto& dims : dot_dnums->batch_dims) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dims.lhs; - mapping.batch_dims.back().rhs = dims.rhs; - mapping.batch_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->contracting_dims) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dims.lhs; - mapping.contracting_dims.back().rhs = dims.rhs; - mapping.contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->lhs_non_contracting_dims) { - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.lhs_non_contracting_dims.back().output = dims.output; - } - for (const auto& dims : dot_dnums->rhs_non_contracting_dims) { - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.rhs_non_contracting_dims.back().output = dims.output; - } - auto create_sharded_conv = - [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, - spmd::SpmdBuilder* b) -> StatusOr { - TF_ASSIGN_OR_RETURN( - auto sharded_conv, - dot_as_convolution_util::CreateShardedConvForDotGeneralConvolution( - *hlo, *dot_dnums, lhs_hlo, rhs_hlo)); - return b->AddInstruction(std::move(sharded_conv)); - }; - return HandleDotHelper(hlo, mapping, create_sharded_conv); - } - - auto lhs = GetPartitionedHlo(hlo->operand(0)); - auto rhs = GetPartitionedHlo(hlo->operand(1)); - const HloSharding& sharding = hlo->sharding(); - const auto& dnums = hlo->convolution_dimension_numbers(); - std::vector rhs_to_lhs_indices(hlo->shape().rank()); - rhs_to_lhs_indices[dnums.kernel_output_feature_dimension()] = - dnums.input_batch_dimension(); - rhs_to_lhs_indices[dnums.kernel_input_feature_dimension()] = - dnums.input_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - rhs_to_lhs_indices[dnums.kernel_spatial_dimensions(i)] = - dnums.input_spatial_dimensions(i); - } - std::vector lhs_to_rhs_indices(hlo->shape().rank()); - for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) { - lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i; - } - auto aligned_rhs_sharding = - hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices); - auto aligned_lhs_sharding = - hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices); - - // Handling cases where both operands' shardings are aligned. We check that - // the LHS batch dimension is not partitioned because it is mapped to the - // output feature dimension in aligned_rhs_sharding, which are not the same - // dimension. - if (!lhs.sharding().IsTileMaximal() && !rhs.sharding().IsTileMaximal()) { - if (options_.conv_halo_exchange_always_on_lhs) { - return HandleConvolutionTiledLhsAndRhs(hlo); - } else { - // Reshard RHS so that each shard computes the partial sum of the full - // shape result, and add AllReduce. See HandleConvolutionTiledLhsAndRhs() - // that reshards LHS. - // - // The size of halo on each dimension can be calculated from the - // projection onto the RHS that shard i needs to read. RHS and LHS below - // refers to the shard size of RHS and LHS, WC is the number of windows, - // and D is the window dilation. 
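Both convolution paths above align the operand shardings through a kernel-to-input dimension map and its inverse. A minimal sketch of that index plumbing with made-up NHWC/HWIO dimension numbers (illustrative values only):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical NHWC input / HWIO kernel dimension numbers.
  const int64_t input_batch = 0, input_feature = 3;
  const std::vector<int64_t> input_spatial = {1, 2};
  const int64_t kernel_out_feature = 3, kernel_in_feature = 2;
  const std::vector<int64_t> kernel_spatial = {0, 1};

  const int64_t rank = 4;
  std::vector<int64_t> rhs_to_lhs(rank);
  rhs_to_lhs[kernel_out_feature] = input_batch;
  rhs_to_lhs[kernel_in_feature] = input_feature;
  for (size_t i = 0; i < input_spatial.size(); ++i) {
    rhs_to_lhs[kernel_spatial[i]] = input_spatial[i];
  }

  // Invert the map so an LHS sharding can also be transposed onto the RHS.
  std::vector<int64_t> lhs_to_rhs(rank);
  for (int64_t i = 0; i < rank; ++i) lhs_to_rhs[rhs_to_lhs[i]] = i;

  for (int64_t i = 0; i < rank; ++i) {
    std::cout << "rhs dim " << i << " <-> lhs dim " << rhs_to_lhs[i] << "\n";
  }
  std::cout << "lhs feature dim " << input_feature << " maps to rhs dim "
            << lhs_to_rhs[input_feature] << "\n";
  return 0;
}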
- // - // * offset(i): LHS * i + low_padding - (WC - 1) * stride - // * limit(i): LHS * (i + 1) + low_padding - // - // Since shard i has RHS of range [i * RHS * D, (i + 1) * RHS * D) - // * left-halo: i * RHS - offset(i) - // = i * (RHS * D - LHS) + (WC - 1) * stride - low_padding - // * right-halo: limit(i) - (i + 1) * RHS - // = (i + 1) * (LHS - RHS * D) + low_pading - - auto unsupported_sharding = [&](const HloSharding& lhs_sharding, - const HloSharding& rhs_sharding) { - // We currently don't support partitioning input batch or output feature - // dimensions. - return lhs_sharding.tile_assignment().dim( - dnums.input_batch_dimension()) != 1 || - rhs_sharding.tile_assignment().dim( - dnums.kernel_output_feature_dimension()) != 1; - }; - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); - if (ShapeSizeInBytes(lhs.base_shape()) < - ShapeSizeInBytes(rhs.base_shape())) { - if (unsupported_sharding(aligned_lhs_sharding, rhs.sharding())) { - return DefaultAction(hlo); - } - lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero); - rhs = rhs.PadWithValue(zero); - } else { - if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) { - return DefaultAction(hlo); - } - lhs = lhs.PadWithValue(zero); - rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero); - } - - Window window = hlo->window(); - std::vector shard_counts(dnums.input_spatial_dimensions_size()); - std::vector lhs_shard_sizes(dnums.input_spatial_dimensions_size()); - std::vector rhs_shard_sizes(dnums.input_spatial_dimensions_size()); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dimension = dnums.input_spatial_dimensions(i); - int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); - int64 shard_count = rhs.sharding().tile_assignment().dim(rhs_dimension); - auto wd = window.dimensions(i); - if (wd.base_dilation() != 1 || wd.window_reversal()) { - return DefaultAction(hlo); - } - - int64 lhs_shard_size = CeilOfRatio( - lhs.base_shape().dimensions(lhs_dimension), shard_count); - int64 rhs_shard_size = CeilOfRatio( - rhs.base_shape().dimensions(rhs_dimension), shard_count); - shard_counts[i] = shard_count; - lhs_shard_sizes[i] = lhs_shard_size; - rhs_shard_sizes[i] = rhs_shard_size; - } - - std::vector left_halo_size_functions( - hlo->shape().rank()); - std::vector right_halo_size_functions( - hlo->shape().rank()); - Window new_window = window; - - // Data structures needed for Pad and DynamicSlice on LHS if needed. - bool need_dynamic_slice_lhs = false; - auto partition_ordinals = - MakeTiledPartitionOrdinals(lhs.sharding(), partition_id_, &b_); - std::vector zero_padding(hlo->shape().rank()); - PaddingConfig pad_config = - window_util::MakeSymmetricPadding(zero_padding); - auto zero_s32 = b_.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); - std::vector dynamic_slice_start_indices( - hlo->shape().rank(), zero_s32); - Shape dynamic_slice_shape = lhs.hlo()->shape(); - Shape pad_shape = lhs.hlo()->shape(); - - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - int64 lhs_dimension = dnums.input_spatial_dimensions(i); - int64 rhs_dimension = dnums.kernel_spatial_dimensions(i); - int64 lhs_shard_size = lhs_shard_sizes[i]; - int64 rhs_shard_size = rhs_shard_sizes[i]; - - if (shard_counts[i] == 1) { - continue; - } - - // Calculate the left and right halo sizes as described in the comments - // above. 
It calculcates the halo sizes with dilation, so we apply - // CeilOfRatio({left,right}_halo_size, window_dilation). - auto wd = window.dimensions(i); - int64 padding_low = wd.padding_low(); - int64 padding_high = wd.padding_high(); - int64 base = lhs.base_shape().dimensions(lhs_dimension); - int64 window_count = - 1 + (padding_low + padding_high + base - - (1 + (wd.size() - 1) * wd.window_dilation())) / - wd.stride(); - left_halo_size_functions[rhs_dimension] = - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - rhs_shard_size * wd.window_dilation() - lhs_shard_size, - (window_count - 1) * wd.stride() - padding_low + - wd.window_dilation() - 1, - wd.window_dilation())); - right_halo_size_functions[rhs_dimension] = - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - lhs_shard_size - rhs_shard_size * wd.window_dilation(), - lhs_shard_size - rhs_shard_size * wd.window_dilation() + - padding_low + wd.window_dilation() - 1, - wd.window_dilation())); - - // New RHS window size includes the maximum of both left and right - // halos. - int64 halo_size = left_halo_size_functions[rhs_dimension].MaxInRange( - 1, shard_counts[i]) + - right_halo_size_functions[rhs_dimension].MaxInRange( - 0, shard_counts[i] - 1); - int64 new_window_size = - rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size; - - // The amount of new low padding could be dynamic (e.g., window_dilation - // != 1), which requires pad (to the maximum) and dynamic slice on LHS. - // - // If we consider the first window, the offset of the dilated RHS that - // aligns with the first valid LHS element for shard i is 'padding_low + - // LHS * i'. When the left halo is added to RHS, the offset of the first - // RHS element is (RHS * i - left_halo) * window_dilation. The - // difference between the two values is the amount of padding_low we - // need on LHS. - auto new_padding_low_function = - OffsetCalculation( - HloOpcode::kMultiply, left_halo_size_functions[rhs_dimension], - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - 0, wd.window_dilation(), 1))) - - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - rhs_shard_size * wd.window_dilation() - lhs_shard_size, - -padding_low, 1)); - - int64 new_padding_low_max = - new_padding_low_function.MaxInRange(0, shard_counts[i]); - int64 new_padding_low = new_padding_low_max; - int64 new_padding_high = window_count * wd.stride() + - (new_window_size - 1) * wd.window_dilation() - - new_padding_low - lhs_shard_size; - - // We do pad/dynamic-slice only when the padding is dynamic. - if (!new_padding_low_function.IsConstant()) { - need_dynamic_slice_lhs = true; - new_padding_low = 0; - pad_config.mutable_dimensions(lhs_dimension) - ->set_edge_padding_low(new_padding_low_max); - pad_config.mutable_dimensions(lhs_dimension) - ->set_edge_padding_high(new_padding_low_max); - pad_shape.set_dimensions(lhs_dimension, - lhs_shard_size + 2 * new_padding_low_max); - dynamic_slice_start_indices[lhs_dimension] = - (OffsetCalculation(MultiplyAddDivideOffsetCalculation( - 0, new_padding_low_max, 1)) - - new_padding_low_function) - .Calculate(partition_ordinals[lhs_dimension], &b_); - dynamic_slice_shape.set_dimensions( - lhs_dimension, lhs_shard_size + new_padding_low_max); - } - - // Since the convolution RHS operand size increased with halos, adjust - // the window config accordingly. 
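MultiplyAddDivideOffsetCalculation evaluates f(i) = (multiplier * i + offset) / divisor per shard index, and the extra RHS window size is the maximum of the left and right halo functions over their shard ranges. A rough stand-in with a brute-forced MaxInRange (the real class solves the maximum analytically) and hypothetical sizes:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Minimal stand-in for MultiplyAddDivideOffsetCalculation on non-negative
// results: f(i) = (m * i + a) / d. MaxInRange is brute-forced here.
struct AffineDivide {
  int64_t multiplier, offset, divisor;
  int64_t Calculate(int64_t i) const {
    return (multiplier * i + offset) / divisor;
  }
  int64_t MaxInRange(int64_t begin, int64_t end) const {  // over [begin, end)
    int64_t best = Calculate(begin);
    for (int64_t i = begin + 1; i < end; ++i) best = std::max(best, Calculate(i));
    return best;
  }
};

int main() {
  // Hypothetical sizes for one spatial dimension of the RHS-halo variant.
  const int64_t lhs_shard = 8, rhs_shard = 2, dilation = 2, stride = 1;
  const int64_t padding_low = 1, window_count = 13, shards = 2;

  AffineDivide left{rhs_shard * dilation - lhs_shard,
                    (window_count - 1) * stride - padding_low + dilation - 1,
                    dilation};
  AffineDivide right{lhs_shard - rhs_shard * dilation,
                     lhs_shard - rhs_shard * dilation + padding_low + dilation - 1,
                     dilation};

  const int64_t halo =
      left.MaxInRange(1, shards) + right.MaxInRange(0, shards - 1);
  std::cout << "extra window size to add to the RHS shard: " << halo << "\n";
  return 0;
}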
- new_window.mutable_dimensions(i)->set_padding_low(new_padding_low); - new_window.mutable_dimensions(i)->set_padding_high(new_padding_high); - new_window.mutable_dimensions(i)->set_size( - rhs.hlo()->shape().dimensions(rhs_dimension) + halo_size); - } - - HloInstruction* conv_lhs = lhs.hlo(); - if (need_dynamic_slice_lhs) { - auto pad = b_.AddInstruction( - HloInstruction::CreatePad(pad_shape, lhs.hlo(), zero, pad_config)); - conv_lhs = b_.AddInstruction(HloInstruction::CreateDynamicSlice( - dynamic_slice_shape, pad, dynamic_slice_start_indices, - dynamic_slice_shape.dimensions())); - } - - // Exchange halo and concatenate. - HloInstruction* rhs_with_halo = rhs.hlo(); - for (int i = 0; i < dnums.kernel_spatial_dimensions_size(); ++i) { - int64 dim = dnums.kernel_spatial_dimensions(i); - int64 explicit_left_padding_on_full_shape = - left_halo_size_functions[dim].Calculate(0); - int64 shard_size_with_halo = new_window.dimensions(i).size(); - - // offset_on_padded_shape and padded_full_shape_size are needed only if - // we want to mask out-of-range values in ExchangeHaloAndGetValidData(). - // Since the default value for both the collective-permute is zero and - // also we call PadWithValue() on both operands at the beginning, we - // don't need to mask here. - // - // TODO(hyoulkee): Consider removing one of the two PadWithValue() calls - // if it's always safe. - auto offset_on_padded_shape = - OffsetCalculation(MultiplyAddDivideOffsetCalculation( - rhs_shard_sizes[i], explicit_left_padding_on_full_shape, 1)) - - left_halo_size_functions[dim]; - int64 padded_full_shape_size = - offset_on_padded_shape.Calculate(shard_counts[i] - 1) + - new_window.dimensions(i).size(); - auto concat = ExchangeHaloAndGetValidData( - rhs_with_halo, rhs.base_shape(), left_halo_size_functions[dim], - right_halo_size_functions[dim], explicit_left_padding_on_full_shape, - padded_full_shape_size, shard_size_with_halo, dim, rhs.sharding(), - offset_on_padded_shape.Calculate(partition_ordinals[dim], &b_), - zero, partition_ordinals[dim], collective_ops_creator_, - next_channel_id_, &b_, /*mask_invalid_region=*/false); - if (!concat) { - return DefaultAction(hlo); - } - rhs_with_halo = *concat; - } - - SetPartitionedHlo(hlo, [&]() { - auto conv = b_.AddInstruction(HloInstruction::CreateConvolve( - hlo->shape(), conv_lhs, rhs_with_halo, hlo->feature_group_count(), - hlo->batch_group_count(), new_window, dnums, - hlo->precision_config())); - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, conv, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); - } - } - - if (!sharding.IsTileMaximal()) { - // We don't currently support sharding on output feature dimension. - if (sharding.tile_assignment().dim(dnums.output_feature_dimension()) > 1) { - return DefaultAction(hlo); - } - - // Check if the operand and the output sharding are aligned. 
- std::vector input_to_output_indices(hlo->shape().rank()); - input_to_output_indices[dnums.input_batch_dimension()] = - dnums.output_batch_dimension(); - input_to_output_indices[dnums.input_feature_dimension()] = - dnums.output_feature_dimension(); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - input_to_output_indices[dnums.input_spatial_dimensions(i)] = - dnums.output_spatial_dimensions(i); - } - auto target_operand_sharding = - hlo_sharding_util::TransposeSharding(sharding, input_to_output_indices); - lhs = lhs.Reshard(target_operand_sharding); - - // Replicate the RHS. - rhs = rhs.Reshard(HloSharding::Replicate()); - - // Convolution window config does not include batch and feature dimensions, - // whereas ReshardAsWindowedInput() expects the same number of window - // dimensions as the rank of the operand. So add two more trivial - // dimensions. - std::vector ones(hlo->shape().rank(), 1); - auto operand_window = window_util::MakeWindow(ones); - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - *operand_window.mutable_dimensions(dnums.input_spatial_dimensions(i)) = - hlo->window().dimensions(i); - } - - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); - auto resharded_operand_and_window = lhs.ReshardAsWindowedInput( - operand_window, target_operand_sharding, zero); - if (!resharded_operand_and_window.has_value()) { - return DefaultAction(hlo); - } - Window new_window; - for (int64 i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { - *new_window.add_dimensions() = - resharded_operand_and_window->shard_window.dimensions( - dnums.input_spatial_dimensions(i)); - } - TF_ASSIGN_OR_RETURN( - Shape sharded_conv_shape, - ShapeInference::InferConvolveShape( - resharded_operand_and_window->sharded_input->shape(), - rhs.hlo()->shape(), hlo->feature_group_count(), - hlo->batch_group_count(), new_window, dnums)); - auto shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding()); - *sharded_conv_shape.mutable_layout() = shard_shape.layout(); - SetPartitionedHlo(hlo, [&]() { - auto sharded_conv = b_.AddInstruction(HloInstruction::CreateConvolve( - sharded_conv_shape, resharded_operand_and_window->sharded_input, - rhs.hlo(), hlo->feature_group_count(), hlo->batch_group_count(), - new_window, dnums, hlo->precision_config())); - if (!resharded_operand_and_window->dynamic_slice_index_on_output - .has_value()) { - CHECK(ShapeUtil::Compatible(shard_shape, sharded_conv->shape())); - return sharded_conv; - } - return b_.AddInstruction(HloInstruction::CreateDynamicSlice( - shard_shape, sharded_conv, - *resharded_operand_and_window->dynamic_slice_index_on_output, - shard_shape.dimensions())); - }); - return Status::OK(); - } - return DefaultAction(hlo); -} - -Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { - DotGeneralDimsMapping mapping; - const auto& dnums = hlo->dot_dimension_numbers(); - int64 next_output_dim = 0; - for (int64 i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); - mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); - mapping.batch_dims.back().output = next_output_dim++; - } - for (int64 i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); - mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); - 
mapping.contracting_dims.back().output = -1; - } - for (int64 i = 0; i < hlo->operand(0)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { - continue; - } - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = i; - mapping.lhs_non_contracting_dims.back().rhs = -1; - mapping.lhs_non_contracting_dims.back().output = next_output_dim++; - } - for (int64 i = 0; i < hlo->operand(1)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { - continue; - } - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = -1; - mapping.rhs_non_contracting_dims.back().rhs = i; - mapping.rhs_non_contracting_dims.back().output = next_output_dim++; - } - auto create_sharded_dot = [&](HloInstruction* l, HloInstruction* r, - SpmdBuilder* b) -> StatusOr { - TF_ASSIGN_OR_RETURN( - auto sharded_dot_shape, - ShapeInference::InferDotOpShape(l->shape(), r->shape(), - hlo->dot_dimension_numbers())); - return b->AddInstruction(HloInstruction::CreateDot( - sharded_dot_shape, l, r, hlo->dot_dimension_numbers(), - hlo->precision_config())); - }; - return HandleDotHelper(hlo, mapping, create_sharded_dot); -} - -Status SpmdPartitioningVisitor::HandleDotHelper( - HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, - const std::function( - HloInstruction*, HloInstruction*, SpmdBuilder*)>& create_sharded_dot) { - const HloSharding& lhs_sharding = hlo->operand(0)->sharding(); - const HloSharding& rhs_sharding = hlo->operand(1)->sharding(); - - // Similar to hlo_sharding_util::TransposeSharding(), but allows - // removing/adding non-partitioned dimensions. 
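The deleted HandleDot builds its dims mapping by classifying every dimension as batch, contracting, or non-contracting and handing out output positions in that order. A self-contained sketch of the classification for a hypothetical [B, M, K] x [B, K, N] dot, using an illustrative struct rather than the real DotGeneralDimsMapping:

#include <cstdint>
#include <iostream>
#include <vector>

struct DimMapping {
  int64_t lhs, rhs, output;  // -1 means "not present on that side"
};

bool Contains(const std::vector<int64_t>& v, int64_t x) {
  for (int64_t e : v) {
    if (e == x) return true;
  }
  return false;
}

int main() {
  // Hypothetical dot: lhs [B, M, K] x rhs [B, K, N] -> out [B, M, N].
  const int64_t lhs_rank = 3, rhs_rank = 3;
  const std::vector<int64_t> lhs_batch = {0}, rhs_batch = {0};
  const std::vector<int64_t> lhs_contracting = {2}, rhs_contracting = {1};

  std::vector<DimMapping> batch, contracting, lhs_nc, rhs_nc;
  int64_t next_output_dim = 0;
  for (size_t i = 0; i < lhs_batch.size(); ++i) {
    batch.push_back({lhs_batch[i], rhs_batch[i], next_output_dim++});
  }
  for (size_t i = 0; i < lhs_contracting.size(); ++i) {
    contracting.push_back({lhs_contracting[i], rhs_contracting[i], -1});
  }
  for (int64_t i = 0; i < lhs_rank; ++i) {
    if (Contains(lhs_batch, i) || Contains(lhs_contracting, i)) continue;
    lhs_nc.push_back({i, -1, next_output_dim++});
  }
  for (int64_t i = 0; i < rhs_rank; ++i) {
    if (Contains(rhs_batch, i) || Contains(rhs_contracting, i)) continue;
    rhs_nc.push_back({-1, i, next_output_dim++});
  }

  std::cout << "batch: " << batch.size()
            << ", contracting: " << contracting.size()
            << ", lhs non-contracting: " << lhs_nc.size()
            << ", rhs non-contracting: " << rhs_nc.size()
            << ", output rank: " << next_output_dim << "\n";
  return 0;
}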
- auto transpose_sharding = - [&](const HloSharding& source, absl::Span src_to_tgt, - absl::Span tgt_to_src) -> absl::optional { - if (source.IsTileMaximal()) { - return source; - } - std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); - int64 skipped_tgt_dims = 0; - for (int64 i = 0; i < tgt_to_src.size(); ++i) { - if (tgt_to_src[i] < 0) { - skipped_tgt_dims++; - } else { - tgt_dims_skipping_new[i] = i - skipped_tgt_dims; - } - } - int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); - std::vector perm(src_to_tgt.size()); - for (int64 i = 0; i < src_to_tgt.size(); ++i) { - if (src_to_tgt[i] < 0) { - if (source.tile_assignment().dim(i) > 1) { - return absl::nullopt; - } - perm[src_to_tgt.size() - skipped_src_dims] = i; - skipped_src_dims--; - } else { - perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; - } - } - auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); - if (skipped_tgt_dims == 0) { - return tgt_sharding; - } - auto reshape_tiles = tgt_sharding.tile_assignment(); - std::vector tgt_tiles(tgt_to_src.size(), 1); - for (int64 i = 0; i < tgt_tiles.size(); ++i) { - if (tgt_to_src[i] >= 0) { - tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); - } - } - reshape_tiles.Reshape(tgt_tiles); - return HloSharding::Tile(reshape_tiles); - }; - - std::vector lhs_to_rhs_indices(hlo->operand(0)->shape().rank(), -1); - std::vector lhs_to_output_indices(hlo->operand(0)->shape().rank(), -1); - std::vector rhs_to_lhs_indices(hlo->operand(1)->shape().rank(), -1); - std::vector rhs_to_output_indices(hlo->operand(1)->shape().rank(), -1); - std::vector output_to_lhs_indices(hlo->shape().rank(), -1); - std::vector output_to_rhs_indices(hlo->shape().rank(), -1); - auto populate_indices_mapping = - [&](const DotGeneralDimsMapping::DimsMapping& mapping) { - if (mapping.lhs >= 0) { - lhs_to_rhs_indices[mapping.lhs] = mapping.rhs; - lhs_to_output_indices[mapping.lhs] = mapping.output; - } - if (mapping.rhs >= 0) { - rhs_to_lhs_indices[mapping.rhs] = mapping.lhs; - rhs_to_output_indices[mapping.rhs] = mapping.output; - } - if (mapping.output >= 0) { - output_to_lhs_indices[mapping.output] = mapping.lhs; - output_to_rhs_indices[mapping.output] = mapping.rhs; - } - }; - for (const auto& mapping : dims_mapping.batch_dims) { - populate_indices_mapping(mapping); - } - for (const auto& mapping : dims_mapping.contracting_dims) { - populate_indices_mapping(mapping); - } - for (const auto& mapping : dims_mapping.lhs_non_contracting_dims) { - populate_indices_mapping(mapping); - } - for (const auto& mapping : dims_mapping.rhs_non_contracting_dims) { - populate_indices_mapping(mapping); - } - auto lhs_sharding_transposed_to_match_rhs = - transpose_sharding(lhs_sharding, lhs_to_rhs_indices, rhs_to_lhs_indices); - auto rhs_sharding_transposed_to_match_lhs = - transpose_sharding(rhs_sharding, rhs_to_lhs_indices, lhs_to_rhs_indices); - auto lhs_sharding_transposed_to_match_output = transpose_sharding( - lhs_sharding, lhs_to_output_indices, output_to_lhs_indices); - auto rhs_sharding_transposed_to_match_output = transpose_sharding( - rhs_sharding, rhs_to_output_indices, output_to_rhs_indices); - auto output_sharding_transposed_to_match_lhs = transpose_sharding( - hlo->sharding(), output_to_lhs_indices, lhs_to_output_indices); - auto output_sharding_transposed_to_match_rhs = transpose_sharding( - hlo->sharding(), output_to_rhs_indices, rhs_to_output_indices); - - // lhs_rhs_or_output: 0 lhs, 1 rhs, 2 output. 
- auto get_partitions_for_dims = - [&](const HloSharding& sharding, - absl::Span dims, - int lhs_rhs_or_output) { - int64 partitions = 1; - if (sharding.IsTileMaximal()) { - return partitions; - } - for (const auto& dim : dims) { - if (lhs_rhs_or_output == 0) { - partitions *= sharding.tile_assignment().dim(dim.lhs); - } else if (lhs_rhs_or_output == 1) { - partitions *= sharding.tile_assignment().dim(dim.rhs); - } else { - CHECK_EQ(lhs_rhs_or_output, 2); - partitions *= sharding.tile_assignment().dim(dim.output); - } - } - return partitions; - }; - const int64 lhs_batch_partitions = - get_partitions_for_dims(lhs_sharding, dims_mapping.batch_dims, 0); - const int64 rhs_batch_partitions = - get_partitions_for_dims(rhs_sharding, dims_mapping.batch_dims, 1); - const int64 output_batch_partitions = - get_partitions_for_dims(hlo->sharding(), dims_mapping.batch_dims, 2); - const int64 lhs_contracting_partitions = - get_partitions_for_dims(lhs_sharding, dims_mapping.contracting_dims, 0); - const int64 rhs_contracting_partitions = - get_partitions_for_dims(rhs_sharding, dims_mapping.contracting_dims, 1); - const int64 lhs_non_contracting_partitions = get_partitions_for_dims( - lhs_sharding, dims_mapping.lhs_non_contracting_dims, 0); - const int64 rhs_non_contracting_partitions = get_partitions_for_dims( - rhs_sharding, dims_mapping.rhs_non_contracting_dims, 1); - const int64 output_lhs_non_contracting_partitions = get_partitions_for_dims( - hlo->sharding(), dims_mapping.lhs_non_contracting_dims, 2); - const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( - hlo->sharding(), dims_mapping.rhs_non_contracting_dims, 2); - - auto& lhs = GetPartitionedHlo(hlo->operand(0)); - auto& rhs = GetPartitionedHlo(hlo->operand(1)); - // LHS and RHS are partitioned the same way and only partitioned in batch - // dimensions. - if (lhs_batch_partitions == rhs_batch_partitions && - rhs_batch_partitions == num_partitions_ && - lhs_sharding_transposed_to_match_rhs == rhs_sharding) { - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - dot->set_sharding(*lhs_sharding_transposed_to_match_output); - return PartitionedHlo(dot, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); - } - - // Try emit batch-partitioned einsum with one operand resharded. Returns - // whether the attempt succeeds. If may_reshard_with_allreduce is false, - // reshard must be done using all-to-all; otherwise this attempt fails. - auto try_emit_output_batch_partitioned_einsum_with_reshard = - [&](bool may_reshard_with_allreduce) -> StatusOr { - // LHS and output are batch partitioned in the same way. - if (lhs_batch_partitions == num_partitions_ && - output_batch_partitions == num_partitions_ && - lhs_sharding_transposed_to_match_output == hlo->sharding()) { - if (!may_reshard_with_allreduce && - !CanReshardWithAllToAll(rhs.sharding(), - *lhs_sharding_transposed_to_match_rhs)) { - return false; - } - auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(lhs.hlo(), resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return true; - } - // RHS and output are batch partitioned in the same way. 
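Strategy selection in HandleDotHelper is driven by how many partitions each dimension group spans, and get_partitions_for_dims is just a product of tile-assignment extents. A tiny sketch for a hypothetical operand tiled 2-way on its batch dimension and 4-way on its contracting dimension:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of get_partitions_for_dims: multiply the tile-assignment extents of
// the requested dimensions; a replicated (tile-maximal) sharding contributes 1.
int64_t PartitionsForDims(const std::vector<int64_t>& tile_dims,
                          const std::vector<int64_t>& dims) {
  int64_t partitions = 1;
  for (int64_t d : dims) partitions *= tile_dims[d];
  return partitions;
}

int main() {
  // Hypothetical LHS [B, M, K] tiled 2-way on B and 4-way on K.
  const std::vector<int64_t> lhs_tile = {2, 1, 4};
  std::cout << "batch partitions: " << PartitionsForDims(lhs_tile, {0}) << "\n";
  std::cout << "contracting partitions: " << PartitionsForDims(lhs_tile, {2})
            << "\n";
  return 0;
}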
- if (rhs_batch_partitions == num_partitions_ && - output_batch_partitions == num_partitions_ && - rhs_sharding_transposed_to_match_output == hlo->sharding()) { - if (!may_reshard_with_allreduce && - !CanReshardWithAllToAll(lhs.sharding(), - *rhs_sharding_transposed_to_match_lhs)) { - return false; - } - auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs); - TF_ASSIGN_OR_RETURN( - auto dot, create_sharded_dot(resharded_lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return true; - } - return false; - }; - - { - // Try batch-parallel by resharding one operand, and not using all-reduce. - TF_ASSIGN_OR_RETURN( - bool emitted, - try_emit_output_batch_partitioned_einsum_with_reshard(false)); - if (emitted) { - return Status::OK(); - } - } - - // Try to emit windowed DotGeneral when one operand is partitioned in the same - // way as the output along non-contracting dimensions, but the other operand - // is tiled in other dimensions. - auto emit_windowed_dot_general = [&](int64 matching_operand, - int64 windowing_operand, - bool windowed_at_contracting_dims, - bool windowed_at_batch_dims) { - CHECK_EQ(matching_operand + windowing_operand, 1); - CHECK(!windowed_at_batch_dims || !windowed_at_contracting_dims); - auto unpadded_result_buffer_shape = - MakePartitionedShape(hlo->shape(), hlo->sharding()); - auto padded_result_buffer_shape = unpadded_result_buffer_shape; - // For windowing at batch/non-contracting dims, we produce the result one - // partition at a time, so we need to pad the shape in case of uneven - // partitioning in order to make dynamic-update-slice in-bound. - if (!windowed_at_contracting_dims) { - padded_result_buffer_shape = GetPaddedShapeForUnevenPartitioning( - padded_result_buffer_shape, - windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output - : *rhs_sharding_transposed_to_match_output); - } - // Mask the padding area of the windowed operand with zero if there is - // uneven partitioning. - if (windowed_at_contracting_dims) { - auto& to_mask = windowing_operand == 0 ? lhs : rhs; - to_mask = - to_mask.PadWithValue(b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type())))); - } - auto result_buffer = CreateZero(padded_result_buffer_shape, &b_); - auto iteration = b_.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); - - // Create a while loop that computes one window per iteration. During each - // iteration, each partition sends its input window to its neighbor using - // collective-permute for the next iteration. 
- SpmdBuilder body_b("windowed_dot_general_body", visiting_hlo_); - auto param = body_b.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, - ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), - result_buffer->shape(), iteration->shape()}), - "param")); - auto l = body_b.AddInstruction( - HloInstruction::CreateGetTupleElement(lhs.hlo()->shape(), param, 0)); - auto r = body_b.AddInstruction( - HloInstruction::CreateGetTupleElement(rhs.hlo()->shape(), param, 1)); - auto o = body_b.AddInstruction(HloInstruction::CreateGetTupleElement( - result_buffer->shape(), param, 2)); - auto i = body_b.AddInstruction( - HloInstruction::CreateGetTupleElement(iteration->shape(), param, 3)); - - auto partition_id = collective_ops_creator_.create_partition_id(&body_b); - auto data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( - i->shape(), HloOpcode::kAdd, i, partition_id)); - auto partition_count = body_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))); - data_partition_id = body_b.AddInstruction(HloInstruction::CreateBinary( - i->shape(), HloOpcode::kRemainder, data_partition_id, partition_count)); - auto dot_lhs = l; - auto dot_rhs = r; - if (windowed_at_contracting_dims || windowed_at_batch_dims) { - // Slice the matching operand according to the partitioned contracting - // dimensions on the windowed operand. We do this by treating the matching - // operand as replicated, and resharding it to match the windowed operand. - auto slice_operand = matching_operand == 0 ? l : r; - slice_operand->set_sharding(HloSharding::Replicate()); - auto state = MakePartitioningState(); - state.b = &body_b; - state.partition_id = data_partition_id; - auto slice = PartitionedHlo(slice_operand, slice_operand->shape(), state) - .Reshard(windowing_operand == 0 - ? *lhs_sharding_transposed_to_match_rhs - : *rhs_sharding_transposed_to_match_lhs) - .hlo(); - slice_operand->clear_sharding(); - if (matching_operand == 0) { - dot_lhs = slice; - } else { - dot_rhs = slice; - } - } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(dot_lhs, dot_rhs, &body_b)); - if (windowed_at_contracting_dims) { - // Accumulate the partial output to the result buffer. - o = body_b.AddInstruction( - HloInstruction::CreateBinary(o->shape(), HloOpcode::kAdd, o, dot)); - } else { - // The windowing operand is partitioned along batch/non-contracting - // dimensions, so we need a dynamic-update-slice to save the partial - // output in the result buffer. - auto offsets = MakePartitionOffsets( - o->shape(), - windowing_operand == 0 ? *lhs_sharding_transposed_to_match_output - : *rhs_sharding_transposed_to_match_output, - data_partition_id, &body_b); - o = body_b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - o->shape(), o, dot, offsets)); - } - - // ++i - i = body_b.AddInstruction(HloInstruction::CreateBinary( - i->shape(), HloOpcode::kAdd, i, - body_b.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))))); - auto has_more = body_b.AddInstruction(HloInstruction::CreateCompare( - ShapeUtil::MakeShape(PRED, {}), i, - body_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))), - ComparisonDirection::kLt)); - // Collective-permute for the next window. We don't need it for the last - // iteration, so we use a conditional around the collective-permute. 
- HloInstruction* conditional; - { - SpmdBuilder cp_b("window_collective_permute", visiting_hlo_); - { - auto p = cp_b.AddInstruction(HloInstruction::CreateParameter( - 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); - std::vector> sd_pairs(num_partitions_); - for (int64 source = 0; source < num_partitions_; ++source) { - // 0 -> n-1, 1 -> 0, 2 -> 1, ... - sd_pairs[source] = {source, - (source - 1 + num_partitions_) % num_partitions_}; - } - collective_ops_creator_.create_cross_partition_collective_permute( - &cp_b, p, sd_pairs, (*next_channel_id_)++); - } - SpmdBuilder ncp_b("last_iteration_noop", visiting_hlo_); - { - ncp_b.AddInstruction(HloInstruction::CreateParameter( - 0, windowing_operand == 0 ? l->shape() : r->shape(), "window")); - } - conditional = body_b.AddInstruction(HloInstruction::CreateConditional( - windowing_operand == 0 ? l->shape() : r->shape(), has_more, - windowing_operand == 0 ? l : r, - module_->AddEmbeddedComputation(cp_b.Build()), - windowing_operand == 0 ? l : r, - module_->AddEmbeddedComputation(ncp_b.Build()))); - } - if (windowing_operand == 0) { - l = conditional; - } else { - r = conditional; - } - body_b.AddInstruction(HloInstruction::CreateTuple({l, r, o, i})); - - SpmdBuilder cond_b("windowed_dot_general_cond", visiting_hlo_); - auto cond_param = cond_b.AddInstruction(HloInstruction::CreateParameter( - /*parameter_number=*/0, - ShapeUtil::MakeTupleShape({lhs.hlo()->shape(), rhs.hlo()->shape(), - result_buffer->shape(), iteration->shape()}), - "param")); - auto cond_i = cond_b.AddInstruction(HloInstruction::CreateGetTupleElement( - iteration->shape(), cond_param, 3)); - cond_b.AddInstruction(HloInstruction::CreateCompare( - ShapeUtil::MakeShape(PRED, {}), cond_i, - cond_b.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(num_partitions_))), - ComparisonDirection::kLt)); - auto while_loop = b_.AddInstruction(HloInstruction::CreateWhile( - cond_param->shape(), module_->AddEmbeddedComputation(cond_b.Build()), - module_->AddEmbeddedComputation(body_b.Build()), - b_.AddInstruction(HloInstruction::CreateTuple( - {lhs.hlo(), rhs.hlo(), result_buffer, iteration})))); - windowed_dot_general_loops_.push_back({while_loop, windowing_operand, - windowed_at_contracting_dims, - windowed_at_batch_dims}); - SetPartitionedHlo(hlo, [&] { - auto result = b_.AddInstruction(HloInstruction::CreateGetTupleElement( - result_buffer->shape(), while_loop, 2)); - if (!ShapeUtil::Compatible(padded_result_buffer_shape, - unpadded_result_buffer_shape)) { - result = b_.AddInstruction(HloInstruction::CreateSlice( - unpadded_result_buffer_shape, result, - std::vector(padded_result_buffer_shape.rank(), 0), - unpadded_result_buffer_shape.dimensions(), - std::vector(padded_result_buffer_shape.rank(), 1))); - } - return result; - }); - return Status::OK(); - }; - if (output_lhs_non_contracting_partitions == num_partitions_ && - output_sharding_transposed_to_match_lhs == lhs_sharding && - ShapeSizeInBytes(hlo->operand(1)->shape()) >= - options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { - if (rhs_contracting_partitions == num_partitions_) { - return emit_windowed_dot_general(0, 1, true, false); - } - if (rhs_non_contracting_partitions == num_partitions_) { - return emit_windowed_dot_general(0, 1, false, false); - } - if (rhs_batch_partitions == num_partitions_) { - return emit_windowed_dot_general(0, 1, false, true); - } - } - if (output_rhs_non_contracting_partitions == num_partitions_ && - output_sharding_transposed_to_match_rhs == rhs_sharding && 
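The schedule of the windowed dot-general loop can be checked in isolation: at iteration t, partition p should hold the operand block that started on partition (p + t) mod n, which is exactly what the source to (source - 1 + n) % n collective-permute pairs maintain. A standalone simulation with plain vectors (no HLO involved):

#include <cstdint>
#include <iostream>
#include <vector>

// Simulates the rotation schedule of the windowed dot-general loop.
// Illustrative only; the real loop moves HLO buffers via collective-permute.
int main() {
  const int64_t n = 4;           // num_partitions_
  std::vector<int64_t> held(n);  // which block each partition currently holds
  for (int64_t p = 0; p < n; ++p) held[p] = p;

  for (int64_t t = 0; t < n; ++t) {
    for (int64_t p = 0; p < n; ++p) {
      const int64_t expected = (p + t) % n;  // data_partition_id in the body
      if (held[p] != expected) std::cout << "schedule mismatch!\n";
    }
    // Collective-permute for the next iteration: source -> source - 1 (mod n).
    std::vector<int64_t> next(n);
    for (int64_t source = 0; source < n; ++source) {
      next[(source - 1 + n) % n] = held[source];
    }
    held = next;
  }
  std::cout << "every partition visited every block exactly once\n";
  return 0;
}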
- ShapeSizeInBytes(hlo->operand(0)->shape()) >= - options_.threshold_for_windowed_einsum_mib * 1024 * 1024) { - if (lhs_contracting_partitions == num_partitions_) { - return emit_windowed_dot_general(1, 0, true, false); - } - if (lhs_non_contracting_partitions == num_partitions_) { - return emit_windowed_dot_general(1, 0, false, false); - } - if (lhs_batch_partitions == num_partitions_) { - return emit_windowed_dot_general(1, 0, false, true); - } - } - - { - // Try batch-parallel by resharding one operand, and allowing all-reduce. - TF_ASSIGN_OR_RETURN( - bool emitted, - try_emit_output_batch_partitioned_einsum_with_reshard(true)); - if (emitted) { - return Status::OK(); - } - } - - // LHS and RHS have the same partitioned contracting dimensions. - if (lhs_contracting_partitions == rhs_contracting_partitions && - lhs_contracting_partitions == num_partitions_) { - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); - // Pad both sides with zero, since NaN at one side cannot be masked by zero - // on the other side. - if (ShapeSizeInBytes(lhs.base_shape()) < - ShapeSizeInBytes(rhs.base_shape())) { - lhs = - lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); - rhs = rhs.PadWithValue(zero); - } else { - lhs = lhs.PadWithValue(zero); - rhs = - rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); - } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()) - .Reshard(hlo->sharding()) - .hlo(); - }); - return Status::OK(); - } - - // LHS and output have the same partitioned non-contracting dimensions. - if (lhs_non_contracting_partitions == num_partitions_ && - output_lhs_non_contracting_partitions == num_partitions_ && - lhs_sharding == hlo->sharding()) { - auto rhs_replicated = rhs.Reshard(HloSharding::Replicate()).hlo(); - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs_replicated, &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); - } - - // RHS and output have the same partitioned non-contracting dimensions. - if (rhs_non_contracting_partitions == num_partitions_ && - output_rhs_non_contracting_partitions == num_partitions_ && - rhs_sharding_transposed_to_match_output == hlo->sharding()) { - auto lhs_replicated = lhs.Reshard(HloSharding::Replicate()).hlo(); - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs_replicated, rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); - } - - // Output is batch partitioned. - if (output_batch_partitions == num_partitions_) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); - } - // Output is partitioned along LHS non-contracting dimensions. 
- if (output_lhs_non_contracting_partitions == num_partitions_) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); - TF_ASSIGN_OR_RETURN( - auto dot, - create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); - } - // Output is partitioned along RHS non-contracting dimensions. - if (output_rhs_non_contracting_partitions == num_partitions_) { - auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), - resharded_rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { return dot; }); - return Status::OK(); - } - - // Returns true if it is beneficial to reshard the operand at `operand_idx` - // across the contracting dimension. - const auto should_partition_contracting_dim = [&](int64 operand_idx) { - if (!hlo->sharding().IsReplicated()) { - return false; - } - - if (operand_idx == 0) { - // If LHS and output are replicated, we compare the cost of all-gather - // on RHS vs all-reduce on the output. - return (rhs_contracting_partitions == num_partitions_) && - lhs.sharding().IsReplicated() && - ShapeUtil::ElementsIn(hlo->operand(1)->shape()) > - ShapeUtil::ElementsIn(hlo->shape()); - } else { - return (lhs_contracting_partitions == num_partitions_) && - rhs.sharding().IsReplicated() && - ShapeUtil::ElementsIn(hlo->operand(0)->shape()) > - ShapeUtil::ElementsIn(hlo->shape()); - } - }; - - // When the output is replicated and one of the operands is partitioned along - // contracting dimension, align the other operand to be partitioned along - // the contracting dimensions. - if (hlo->sharding().IsReplicated() && (should_partition_contracting_dim(0) || - should_partition_contracting_dim(1))) { - auto zero = b_.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(hlo->shape().element_type()))); - if (should_partition_contracting_dim(0)) { - lhs = - lhs.Reshard(*rhs_sharding_transposed_to_match_lhs).PadWithValue(zero); - rhs = rhs.PadWithValue(zero); - } else { - lhs = lhs.PadWithValue(zero); - rhs = - rhs.Reshard(*lhs_sharding_transposed_to_match_rhs).PadWithValue(zero); - } - TF_ASSIGN_OR_RETURN(auto dot, - create_sharded_dot(lhs.hlo(), rhs.hlo(), &b_)); - SetPartitionedHlo(hlo, [&] { - auto ar = collective_ops_creator_.create_cross_partition_all_reduce( - &b_, dot, MakeBinaryAdd(hlo->shape().element_type(), module_), - NewChannel()); - ar->set_sharding(HloSharding::Replicate()); - return PartitionedHlo(ar, hlo->shape(), MakePartitioningState()).hlo(); - }); - return Status::OK(); - } - - return DefaultAction(hlo); -} - -namespace { - -// Finds a cluster of nodes that produce the inputs for `hlo` which only depend -// on small operands, which means the cluster should start with broadcasts, -// constants and iotas. All other internal nodes must be non-side-effecting -// elemntwise ops. Returns the set of nodes, and the small operands. E.g., for -// the following graph, -// -// a -> broadcast -> multiply -// iota ---> add--/ -// constant/ -// -// FindInputNodesIfOnlyDependOnSmallOperands(multiply) will return -// <{broadcast, iota, constant, add, multiply}, [a]>. 
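should_partition_contracting_dim above is a data-movement trade-off: with a replicated output, keeping the contracting dimension partitioned swaps an all-gather of the other operand for an all-reduce of the output, so it only pays off when that operand is larger than the output. A sketch of the element-count comparison with hypothetical sizes:

#include <cstdint>
#include <iostream>

// Illustrative cost check for a hypothetical einsum [M, K] x [K, N] -> [M, N]
// where the RHS is already partitioned along K and the output is replicated.
int main() {
  const int64_t m = 1024, k = 65536, n = 256;
  const int64_t rhs_elements = k * n;  // what an all-gather of the RHS moves
  const int64_t out_elements = m * n;  // what an all-reduce of the output moves
  const bool partition_contracting = rhs_elements > out_elements;
  std::cout << (partition_contracting
                    ? "keep K partitioned and all-reduce the output\n"
                    : "all-gather the smaller operand instead\n");
  return 0;
}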
-std::pair, std::vector> -FindInputNodesIfOnlyDependOnSmallOperands(HloInstruction* hlo) { - std::unordered_set nodes_found; - std::vector new_operands; - std::unordered_set new_operands_set; - std::vector worklist; - worklist.push_back(hlo); - while (!worklist.empty()) { - auto inst = worklist.back(); - worklist.pop_back(); - if (nodes_found.count(inst) > 0) { - continue; - } - if (inst->opcode() == HloOpcode::kBroadcast || - inst->opcode() == HloOpcode::kConstant || - inst->opcode() == HloOpcode::kIota) { - nodes_found.insert(inst); - for (auto o : inst->operands()) { - auto res = new_operands_set.emplace(o); - if (res.second) { - new_operands.push_back(o); - } - } - } else if (inst->IsElementwise() && !inst->HasSideEffectNoRecurse() && - inst->opcode() != HloOpcode::kAllReduce && - absl::c_all_of(inst->operands(), - [inst](const HloInstruction* o) { - return ShapeUtil::CompatibleIgnoringElementType( - o->shape(), inst->shape()); - })) { - nodes_found.insert(inst); - for (auto o : inst->operands()) { - worklist.push_back(o); - } - } else { - nodes_found.clear(); - new_operands.clear(); - break; - } - } - return {std::move(nodes_found), std::move(new_operands)}; -} - -// Moves a cluster of memory-reducing nodes into the windowed dot-general loop -// on contracting dimensions. Such a loop has a dynamic slice on the -// non-windowed operand. If we move the input nodes into the loop, the -// dynamic-slice could be merged with them by later optimization passes, which -// reduces memory. -// -// small_operands small_operands -// | | -// input_nodes loop { | -// | => input_nodes -// loop { | | -// dynamic-slice dynamic-slice -// ... ... -// } } -// -// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice -// with the input nodes. -Status SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( - HloInstruction* loop, int64 non_windowed_operand_index) { - auto input_tuple = loop->mutable_operand(0); - auto old_operand = input_tuple->mutable_operand(non_windowed_operand_index); - auto input_nodes = FindInputNodesIfOnlyDependOnSmallOperands(old_operand); - auto to_sink = std::move(input_nodes.first); - auto new_operands = std::move(input_nodes.second); - if (to_sink.empty()) { - return Status::OK(); - } - auto computation = loop->parent(); - // Replace the old operand with a tuple of the found small operands. - auto new_input_subtuple = - computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR(input_tuple->ReplaceOperandWithDifferentShape( - non_windowed_operand_index, new_input_subtuple)); - - auto body = loop->while_body(); - auto body_param = body->parameter_instruction(0); - auto old_body_param_users = body_param->users(); - // Update all tuple shapes. - for (auto tuple : std::vector{ - input_tuple, loop, loop->while_condition()->parameter_instruction(0), - body_param, body->root_instruction()}) { - *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), - {non_windowed_operand_index}) = - new_input_subtuple->shape(); - } - // Now update the loop body. - auto new_operand_tuple_inside = - body->AddInstruction(HloInstruction::CreateGetTupleElement( - new_input_subtuple->shape(), body_param, non_windowed_operand_index)); - TF_RETURN_IF_ERROR(body->root_instruction()->ReplaceOperandWithDifferentShape( - non_windowed_operand_index, new_operand_tuple_inside)); - - // Create nodes inside the loop body. 
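FindInputNodesIfOnlyDependOnSmallOperands is a worklist search that succeeds only if every path from the root bottoms out in a broadcast, constant, or iota, passing through nothing but non-side-effecting elementwise ops. A toy version over a hypothetical node type (not HloInstruction), with add/multiply standing in for the elementwise check:

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Toy operand-graph node; the string opcode is purely illustrative.
struct Node {
  std::string op;  // "broadcast", "constant", "iota", "add", "multiply", ...
  std::vector<Node*> operands;
};

// Returns false (and clears the outputs) if any node on the way is neither a
// small-producing leaf nor an elementwise op, mirroring the abort above.
bool Find(Node* root, std::unordered_set<Node*>* found,
          std::vector<Node*>* small_operands) {
  std::vector<Node*> worklist = {root};
  std::unordered_set<Node*> seen_operands;
  while (!worklist.empty()) {
    Node* n = worklist.back();
    worklist.pop_back();
    if (found->count(n)) continue;
    if (n->op == "broadcast" || n->op == "constant" || n->op == "iota") {
      found->insert(n);
      for (Node* o : n->operands) {
        if (seen_operands.insert(o).second) small_operands->push_back(o);
      }
    } else if (n->op == "add" || n->op == "multiply") {  // "elementwise" stand-in
      found->insert(n);
      for (Node* o : n->operands) worklist.push_back(o);
    } else {
      found->clear();
      small_operands->clear();
      return false;
    }
  }
  return true;
}

int main() {
  Node a{"parameter", {}};
  Node bcast{"broadcast", {&a}};
  Node iota{"iota", {}};
  Node add{"add", {&bcast, &iota}};
  std::unordered_set<Node*> found;
  std::vector<Node*> small;
  std::cout << (Find(&add, &found, &small) ? "sinkable" : "not sinkable")
            << ", nodes=" << found.size()
            << ", small operands=" << small.size() << "\n";
  return 0;
}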
- std::vector worklist; - std::unordered_map outside_to_inside; - auto add_users_if_available = [&](HloInstruction* inst) { - for (auto u : inst->users()) { - if (outside_to_inside.count(u) == 0 && to_sink.count(u) > 0 && - absl::c_all_of(u->operands(), [&](const HloInstruction* o) { - return outside_to_inside.count(o) > 0; - })) { - worklist.push_back(u); - } - } - }; - for (int64 i = 0; i < new_operands.size(); ++i) { - outside_to_inside[new_operands[i]] = - body->AddInstruction(HloInstruction::CreateGetTupleElement( - new_operands[i]->shape(), new_operand_tuple_inside, i)); - add_users_if_available(new_operands[i]); - } - // HLOs to sink without operands. - std::vector nullaries_to_sink; - for (auto inst : to_sink) { - if (inst->operand_count() == 0) { - nullaries_to_sink.push_back(inst); - } - } - // Sort nullaries_to_sink to make it deterministic. - absl::c_sort(nullaries_to_sink, - [](const HloInstruction* a, const HloInstruction* b) { - return a->unique_id() < b->unique_id(); - }); - for (auto inst : nullaries_to_sink) { - worklist.push_back(inst); - } - while (!worklist.empty()) { - auto inst = worklist.back(); - worklist.pop_back(); - std::vector inst_new_operands(inst->operand_count()); - for (int64 i = 0; i < inst->operand_count(); ++i) { - inst_new_operands[i] = outside_to_inside[inst->operand(i)]; - } - outside_to_inside[inst] = body->AddInstruction( - inst->CloneWithNewOperands(inst->shape(), inst_new_operands)); - add_users_if_available(inst); - } - TF_RET_CHECK(outside_to_inside.count(old_operand) > 0); - for (auto ou : old_body_param_users) { - if (ou->opcode() == HloOpcode::kGetTupleElement && - ou->tuple_index() == non_windowed_operand_index) { - TF_RETURN_IF_ERROR( - ou->ReplaceAllUsesWith(outside_to_inside[old_operand])); - TF_RETURN_IF_ERROR(body->RemoveInstruction(ou)); - } - } - return Status::OK(); -} - -// Moves a cluster of memory-reducing nodes (with reduce nodes at the end) into -// the windowed dot-general loop on non-contracting dimensions. Such a loop has -// a dynamic-update-slice at the output. If we move the user nodes into the loop -// and before the dynamic-update-slice, the user nodes can operate on smaller -// shapes, which reduces memory. -// -// small_operands small_operands -// | | => | | -// | | loop { loop { | | -// | | conv | broadcast conv -// | | | | | / -// | | dynamic-update-slice | dynamic-slice / -// | | | | | / -// | | } | | multiply----- -// |broadcast / | / -// | | / reduce -// |multiply-- | -// \ | dynamic-update-slice -// reduce } -// -// Later optimization passes (TpuPadSliceMover) will merge the dynamic slice -// with the input nodes (broadcast). -Status MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( - HloInstruction* loop) { - CHECK_EQ(loop->user_count(), 1); - // There should be a single direct user of the while loop, which is the - // gte for element 2, i.e., the dot output. - auto user_gte = loop->users().front(); - CHECK_EQ(user_gte->opcode(), HloOpcode::kGetTupleElement); - CHECK_EQ(user_gte->tuple_index(), 2); - auto computation = loop->parent(); - - // Find the reduce outputs and the input nodes they depend on, if input nodes - // only have small operands. 
- std::unordered_set to_move; - std::vector new_operands; - std::unordered_set new_operands_set; - std::vector reduce_outputs; - std::vector worklist; - Shape padded_shape = user_gte->shape(); - Shape unpadded_shape = user_gte->shape(); - auto original_output = user_gte; - - if (user_gte->user_count() == 1 && - user_gte->users().back()->opcode() == HloOpcode::kSlice) { - original_output = user_gte->users().back(); - unpadded_shape = original_output->shape(); - } - for (auto u : original_output->users()) { - worklist.push_back(u); - } - to_move.insert(original_output); - while (!worklist.empty()) { - auto inst = worklist.back(); - worklist.pop_back(); - if (to_move.count(inst) > 0) { - continue; - } - // We only support reduces with simple reduction function, since we may need - // to accumulate across iterations manually. - if (inst->opcode() == HloOpcode::kReduce && - inst->to_apply()->instruction_count() == 3 && - inst->to_apply()->num_parameters() == 2 && - inst->to_apply()->root_instruction()->IsElementwise()) { - to_move.insert(inst); - auto other_operand = inst->mutable_operand(1); - auto res = new_operands_set.emplace(other_operand); - if (res.second) { - new_operands.push_back(other_operand); - } - reduce_outputs.push_back(inst); - } else if (inst != computation->root_instruction() && - inst->user_count() > 0 && inst->IsElementwise() && - !inst->HasSideEffectNoRecurse() && - inst->opcode() != HloOpcode::kAllReduce && - absl::c_all_of(inst->operands(), - [inst](const HloInstruction* o) { - return ShapeUtil::CompatibleIgnoringElementType( - o->shape(), inst->shape()); - })) { - // For an elementwise op, we need to make sure that they depend on only - // nodes already in to_move and nodes with small operands. - bool can_include = true; - for (auto operand : inst->operands()) { - if (to_move.count(operand) > 0) { - continue; - } - auto find_result = FindInputNodesIfOnlyDependOnSmallOperands(operand); - if (find_result.first.empty()) { - can_include = false; - break; - } - for (auto n : find_result.first) { - to_move.insert(n); - } - for (auto new_operand : find_result.second) { - auto res = new_operands_set.insert(new_operand); - if (res.second) { - new_operands.push_back(new_operand); - } - } - } - if (!can_include) { - to_move.clear(); - break; - } - to_move.insert(inst); - for (auto u : inst->users()) { - worklist.push_back(u); - } - } else { - to_move.clear(); - break; - } - } - // If nothing is found, to_move could contain only original_output, or cleared - // by the above code. - if (to_move.size() <= 1) { - return Status::OK(); - } - - // We will replace the original loop output with reduce-shape outputs. Create - // the initial buffers before the loop. - for (auto out : reduce_outputs) { - auto padded_out_shape = out->shape(); - int64 operand_dim = 0; - int64 output_dim = 0; - while (output_dim < padded_out_shape.rank()) { - if (absl::c_linear_search(out->dimensions(), operand_dim)) { - // Dimension colapsed. - ++operand_dim; - continue; - } - // Kept dimensions have the same size of the padded shape. 
- padded_out_shape.set_dimensions(output_dim, - padded_shape.dimensions(operand_dim)); - ++operand_dim; - ++output_dim; - } - auto broadcast = - computation->AddInstruction(HloInstruction::CreateBroadcast( - padded_out_shape, - computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(out->shape().element_type()))), - {})); - new_operands.push_back(broadcast); - } - - auto input_tuple = loop->mutable_operand(0); - // Create the new input subtuple that contains the small operands and the - // reduce-shape result buffers. - auto new_input_subtuple = - computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR( - input_tuple->ReplaceOperandWithDifferentShape(2, new_input_subtuple)); - auto body = loop->while_body(); - auto body_param = body->parameter_instruction(0); - auto body_root = body->root_instruction(); - CHECK_EQ(body_root->opcode(), HloOpcode::kTuple); - // Update tuple shapes. - for (auto tuple : std::vector{ - input_tuple, loop, loop->while_condition()->parameter_instruction(0), - body_param, body_root}) { - *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {2}) = - new_input_subtuple->shape(); - } - auto new_loop_input = - body->AddInstruction(HloInstruction::CreateGetTupleElement( - new_input_subtuple->shape(), body_param, 2)); - - // Now create the moved nodes inside the loop body. - std::unordered_map outside_to_inside; - worklist.clear(); - auto add_users_if_available = [&](HloInstruction* inst) { - for (auto u : inst->users()) { - if (outside_to_inside.count(u) == 0 && to_move.count(u) > 0 && - absl::c_all_of(u->operands(), [&](const HloInstruction* o) { - return outside_to_inside.count(o) > 0; - })) { - worklist.push_back(u); - } - } - }; - for (int64 i = 0; i < new_operands.size(); ++i) { - outside_to_inside[new_operands[i]] = - body->AddInstruction(HloInstruction::CreateGetTupleElement( - new_operands[i]->shape(), new_loop_input, i)); - add_users_if_available(new_operands[i]); - } - // The elementwise nodes will be created with sliced shape. The original loop - // output corresponds to the dynamic-update-slice's update slice. - auto dus = body_root->mutable_operand(2); - CHECK_EQ(dus->opcode(), HloOpcode::kDynamicUpdateSlice); - outside_to_inside[original_output] = dus->mutable_operand(1); - add_users_if_available(original_output); - std::vector slice_offsets(padded_shape.rank()); - for (int64 i = 0; i < slice_offsets.size(); ++i) { - slice_offsets[i] = dus->mutable_operand(i + 2); - } - auto get_slice = [&](HloInstruction* padded) { - return body->AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::ChangeElementType(dus->operand(1)->shape(), - padded->shape().element_type()), - padded, slice_offsets, dus->operand(1)->shape().dimensions())); - }; - // Helper functions to create nodes with small operands. 
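The slice offsets gathered above come straight from the dynamic-update-slice's operand list, whose layout is {destination, update, offset_0, ..., offset_{rank-1}}. A tiny sketch of that indexing, using a hypothetical FakeHlo in place of the real HloInstruction API:

#include <cstdint>
#include <vector>

struct FakeHlo {
  std::vector<FakeHlo*> operands;
};

// The offset for dimension i lives at operand index i + 2, after the
// destination (0) and the update slice (1).
std::vector<FakeHlo*> SliceOffsetsFromDynamicUpdateSlice(FakeHlo* dus,
                                                         int64_t rank) {
  std::vector<FakeHlo*> offsets(rank);
  for (int64_t i = 0; i < rank; ++i) {
    offsets[i] = dus->operands[i + 2];
  }
  return offsets;
}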
- auto add_broadcast = [&](const HloInstruction* broadcast) { - auto padded_operand_shape = broadcast->operand(0)->shape(); - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - padded_operand_shape.set_dimensions( - i, padded_shape.dimensions(broadcast->dimensions(i))); - } - auto padded_operand = PadToShape(outside_to_inside[broadcast->operand(0)], - padded_operand_shape, nullptr, body); - outside_to_inside[broadcast] = - get_slice(body->AddInstruction(broadcast->CloneWithNewOperands( - ShapeUtil::ChangeElementType(padded_shape, - padded_operand_shape.element_type()), - {padded_operand}))); - }; - auto add_iota = [&](const HloInstruction* iota) { - outside_to_inside[iota] = - get_slice(body->AddInstruction(iota->CloneWithNewOperands( - ShapeUtil::ChangeElementType(padded_shape, - iota->shape().element_type()), - {}))); - }; - auto add_constant = [&](const HloInstruction* constant) { - outside_to_inside[constant] = body->AddInstruction(constant->Clone()); - outside_to_inside[constant] = get_slice( - PadToShape(outside_to_inside[constant], - ShapeUtil::ChangeElementType( - padded_shape, constant->shape().element_type()), - nullptr, body)); - }; - while (!worklist.empty()) { - auto inst = worklist.back(); - worklist.pop_back(); - if (outside_to_inside.count(inst) > 0) { - continue; - } - if (inst->opcode() == HloOpcode::kBroadcast) { - add_broadcast(inst); - } else if (inst->opcode() == HloOpcode::kIota) { - add_iota(inst); - } else if (inst->opcode() == HloOpcode::kConstant) { - add_constant(inst); - } else if (inst->opcode() == HloOpcode::kReduce) { - // This is an output, for which we has special handling later. - } else { - std::vector operands_inside(inst->operand_count()); - for (int64 i = 0; i < operands_inside.size(); ++i) { - operands_inside[i] = outside_to_inside[inst->operand(i)]; - } - outside_to_inside[inst] = body->AddInstruction(inst->CloneWithNewOperands( - ShapeUtil::ChangeElementType(dus->operand(1)->shape(), - inst->shape().element_type()), - operands_inside)); - } - add_users_if_available(inst); - } - std::vector new_outputs_inside(new_operands.size()); - for (int64 i = 0; i < new_outputs_inside.size(); ++i) { - new_outputs_inside[i] = outside_to_inside[new_operands[i]]; - } - // Now create the reduce outpus inside of the loop. - for (int64 i = 0; i < reduce_outputs.size(); ++i) { - auto reduce_outside = reduce_outputs[i]; - CHECK_EQ(reduce_outside->opcode(), HloOpcode::kReduce); - int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; - auto last_iter_result = outside_to_inside[new_operands[index_in_operand]]; - auto operand0 = outside_to_inside[reduce_outside->operand(0)]; - auto operand1 = outside_to_inside[reduce_outside->operand(1)]; - TF_ASSIGN_OR_RETURN(auto reduce_shape, - ShapeInference::InferReduceShape( - {&operand0->shape(), &operand1->shape()}, - reduce_outside->dimensions(), - reduce_outside->to_apply()->ComputeProgramShape())); - *reduce_shape.mutable_layout() = reduce_outside->shape().layout(); - std::vector reduce_dus_offsets; - // If any collapsed dimension is windowed, we need to accumulate with last - // iteration's result. If such a dimension has padding, we also need to mask - // off invalid data. 
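The masking step described above can be pictured with plain arithmetic: any element whose global index (local index plus the shard's slice offset) falls past the valid size is replaced by the reduction's identity value before accumulating. A standalone sketch of that idea, not the HLO select/compare implementation:

#include <cstdint>
#include <vector>

void MaskInvalidTail(std::vector<float>& shard, int64_t slice_offset,
                     int64_t valid_size, float identity) {
  for (int64_t i = 0; i < static_cast<int64_t>(shard.size()); ++i) {
    // Mirrors select(compare(iota + offset < limit), data, identity).
    if (i + slice_offset >= valid_size) shard[i] = identity;
  }
}
// Example: a shard of 4 elements at offset 6 with valid_size 8 masks the last
// 2 elements.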
- bool needs_accumulate = false; - std::vector dims_to_mask; - for (int64 i = 0; i < slice_offsets.size(); ++i) { - if (absl::c_linear_search(reduce_outside->dimensions(), i)) { - if (reduce_outside->operand(0)->shape().dimensions(i) != - operand0->shape().dimensions(i)) { - needs_accumulate = true; - if (unpadded_shape.dimensions(i) != padded_shape.dimensions(i)) { - dims_to_mask.push_back(i); - } - } - continue; - } - reduce_dus_offsets.push_back(slice_offsets[i]); - } - // Mask off invalid data in collapsed dimensions. - for (int64 dim : dims_to_mask) { - auto iota = body->AddInstruction(HloInstruction::CreateIota( - ShapeUtil::ChangeElementType(operand0->shape(), S32), dim)); - auto add = body->AddInstruction(HloInstruction::CreateBinary( - iota->shape(), HloOpcode::kAdd, iota, - body->AddInstruction(HloInstruction::CreateBroadcast( - iota->shape(), slice_offsets[dim], {})))); - auto limit = body->AddInstruction(HloInstruction::CreateBroadcast( - iota->shape(), - body->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0( - reduce_outside->operand(0)->shape().dimensions(dim)))), - {})); - auto compare = body->AddInstruction(HloInstruction::CreateCompare( - ShapeUtil::ChangeElementType(iota->shape(), PRED), add, limit, - ComparisonDirection::kLt)); - operand0 = body->AddInstruction(HloInstruction::CreateTernary( - operand0->shape(), HloOpcode::kSelect, compare, operand0, - body->AddInstruction(HloInstruction::CreateBroadcast( - operand0->shape(), operand1, {})))); - } - auto output_inside = - body->AddInstruction(reduce_outside->CloneWithNewOperands( - reduce_shape, {operand0, operand1})); - // Accumulate with previous results if needed. - if (needs_accumulate) { - auto input_slice = - body->AddInstruction(HloInstruction::CreateDynamicSlice( - output_inside->shape(), last_iter_result, reduce_dus_offsets, - output_inside->shape().dimensions())); - output_inside = body->AddInstruction(HloInstruction::CreateBinary( - output_inside->shape(), - reduce_outside->to_apply()->root_instruction()->opcode(), - output_inside, input_slice)); - } - // Dynamic-update-slice if needed. - if (!ShapeUtil::Compatible(output_inside->shape(), - last_iter_result->shape())) { - output_inside = - body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - last_iter_result->shape(), last_iter_result, output_inside, - reduce_dus_offsets)); - } - new_outputs_inside[index_in_operand] = output_inside; - } - // Body output. - auto new_output_inside = - body->AddInstruction(HloInstruction::CreateTuple(new_outputs_inside)); - TF_RETURN_IF_ERROR( - body_root->ReplaceOperandWithDifferentShape(2, new_output_inside)); - TF_RETURN_IF_ERROR(body->RemoveInstructionAndUnusedOperands(dus)); - // Replace uses of the reduces outside the loop. 
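The accumulate-across-iterations behavior above reduces each iteration's slice and combines it with the previous iteration's buffer using the reduce computation's root operation. A scalar sketch with addition as the combiner, illustrative only and assuming a simple add reduction:

#include <vector>

float ReduceShard(const std::vector<float>& shard) {
  float partial = 0.0f;
  for (float v : shard) partial += v;
  return partial;
}

float AccumulateOverIterations(const std::vector<std::vector<float>>& shards) {
  float acc = 0.0f;                  // the zero-initialized buffer created before the loop
  for (const auto& shard : shards) {
    acc = acc + ReduceShard(shard);  // binary op with last iteration's result
  }
  return acc;
}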
- auto new_output_gte = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - new_output_inside->shape(), loop, 2)); - for (int64 i = 0; i < reduce_outputs.size(); ++i) { - int64 index_in_operand = new_operands.size() - reduce_outputs.size() + i; - auto new_output = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - new_outputs_inside[index_in_operand]->shape(), new_output_gte, - index_in_operand)); - if (!ShapeUtil::Compatible(new_output->shape(), - reduce_outputs[i]->shape())) { - new_output = computation->AddInstruction(HloInstruction::CreateSlice( - reduce_outputs[i]->shape(), new_output, - std::vector(new_output->shape().rank(), 0), - reduce_outputs[i]->shape().dimensions(), - std::vector(new_output->shape().rank(), 1))); - } - TF_RETURN_IF_ERROR(reduce_outputs[i]->ReplaceAllUsesWith(new_output)); - TF_RETURN_IF_ERROR( - computation->RemoveInstructionAndUnusedOperands(reduce_outputs[i])); - } - return Status::OK(); -} - -} // namespace - -Status SpmdPartitioningVisitor::DoCodeMotionForWindowedDotGeneralLoops( - HloComputation* computation) { - for (auto& loop : windowed_dot_general_loops_) { - if (loop.windowed_in_contracting_dims || loop.windowed_in_batch_dims) { - // We have a dynamic-slice for the non-windowed operand in - // batch/contracting-dim windowed dot-general. So moving the - // broadcast/iota/elementwise ops into the loop could help reduce memory - // via fusion. - TF_RETURN_IF_ERROR( - SinkInputNodesIntoWindowedDotGeneralLoopOnContractingDimensions( - loop.while_loop, 1 - loop.windowed_operand)); - } - if (!loop.windowed_in_contracting_dims) { - // We have a dynamic-update-slice for the output in - // batch/non-contracting-dim windowed dot-general. So moving reduce ops - // into the loop could help reduce memory. - TF_RETURN_IF_ERROR( - MoveUsersIntoWindowedDotGeneralLoopOnNonContractingDimensions( - loop.while_loop)); - } - } - return Status::OK(); -} - StatusOr SpmdPartitioningVisitor::DoPartition( HloComputation* computation, const HloSharding& root_sharding) { VLOG(2) << "Partitioning computation " << computation->name() << " for " @@ -4648,13 +3048,36 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, [](SpmdBuilder* b) { return b->AddInstruction(HloInstruction::CreatePartitionId()); }, - [num_replicas](SpmdBuilder* b, HloInstruction* operand, - HloComputation* reduction, int64 channel_id) { + [num_replicas, num_partitions]( + SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, + const std::vector>& partition_subgroups, + int64 channel_id) { + if (partition_subgroups.size() <= 1) { + std::vector groups(num_replicas); + // TODO(yuanzx): Unify subgroup definition with AllToAll. 
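The subgrouped all-reduce added above flattens (replica, partition) pairs into global device ids as replica * num_partitions + partition, producing one group per (replica, partition subgroup) pair. A standalone sketch of that group construction, with plain vectors in place of ReplicaGroup:

#include <cstdint>
#include <vector>

std::vector<std::vector<int64_t>> DeviceGroups(
    int64_t num_replicas, int64_t num_partitions,
    const std::vector<std::vector<int64_t>>& partition_subgroups) {
  std::vector<std::vector<int64_t>> groups;
  groups.reserve(partition_subgroups.size() * num_replicas);
  for (int64_t r = 0; r < num_replicas; ++r) {
    for (const auto& pgroup : partition_subgroups) {
      groups.emplace_back();
      for (int64_t pid : pgroup) {
        groups.back().push_back(r * num_partitions + pid);
      }
    }
  }
  return groups;
}
// With 2 replicas, 4 partitions and subgroups {{0,1},{2,3}}, this yields
// {{0,1},{2,3},{4,5},{6,7}}, consistent with use_global_device_ids=true.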
+ for (int64 i = 0; i < num_replicas; ++i) { + groups[i].add_replica_ids(i); + } + return b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction, groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + } + + std::vector device_groups; + device_groups.reserve(partition_subgroups.size() * num_replicas); + for (int64 i = 0; i < num_replicas; ++i) { + for (const auto& pgroup : partition_subgroups) { + device_groups.emplace_back(); + for (int64 pid : pgroup) { + device_groups.back().add_replica_ids(i * num_partitions + pid); + } + } + } return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, - CreateReplicaGroups(num_replicas), + operand->shape(), {operand}, reduction, device_groups, /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); + /*use_global_device_ids=*/true)); }, [](SpmdBuilder* b, HloInstruction* operand, std::vector>& src_dst_pairs, @@ -4663,14 +3086,20 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions, operand->shape(), operand, src_dst_pairs, channel_id)); }, [](SpmdBuilder* b, absl::Span operands, - const std::vector& replica_groups, int64 channel_id, - absl::optional split_dimension) { + const std::vector>& partition_subgroups, + int64 channel_id, absl::optional split_dimension) { std::vector shapes(operands.size(), operands[0]->shape()); const Shape output_shape = (shapes.size() == 1) ? shapes[0] : ShapeUtil::MakeTupleShape(shapes); + std::vector groups(partition_subgroups.size()); + for (int64 i = 0; i < groups.size(); ++i) { + for (int64 id : partition_subgroups[i]) { + groups[i].add_replica_ids(id); + } + } return b->AddInstruction(HloInstruction::CreateAllToAll( - output_shape, operands, replica_groups, + output_shape, operands, groups, /*constrain_layout=*/false, channel_id, split_dimension)); }, [num_replicas, num_partitions]( @@ -4701,10 +3130,10 @@ SpmdPartitioner::SpmdPartitioner(int64 num_partitions, int64 num_replicas, num_partitions, num_replicas, std::move(options), GetDefaultCollectiveOpsCreator(num_partitions, num_replicas)) {} -HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, - HloInstruction* operand, - const HloSharding& sharding, - int64 channel_id) { +HloInstruction* SpmdPartitioner::AllGatherShards( + SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, + int64 channel_id, absl::Span selected_dims, + const SPMDCollectiveOpsCreator& collectives_creator) { CHECK(!sharding.IsTileMaximal()); // Add one leading dimension to gather all partitions. std::vector shape; @@ -4714,18 +3143,17 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, } auto reshape = b->AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(operand->shape().element_type(), shape), operand)); - std::vector> partition_subgroups(1); - for (int64 pid : sharding.tile_assignment()) { - partition_subgroups[0].push_back(pid); - } - shape[0] = sharding.tile_assignment().num_elements(); - auto result = collective_ops_creator_.create_cross_partition_all_gather( + auto partition_subgroups = + GetPartitionGroupsForReplication(sharding, selected_dims); + shape[0] = partition_subgroups[0].size(); + auto result = collectives_creator.create_cross_partition_all_gather( b, reshape, ShapeUtil::MakeShape(operand->shape().element_type(), shape), partition_subgroups, channel_id, /*all_gather_dimension=*/0); // If n > 1 dimensions are partitioned, split the leading dimension to n. 
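AllGatherShards above prepends a size-1 dimension, all-gathers along it over the selected partition groups, splits the gathered dimension into one component per selected tiled dimension, transposes each component next to its data dimension, and reshapes to the combined shape. The net effect on shapes can be summarized with a small helper; this is illustrative shape arithmetic only, not the partitioner API:

#include <cstdint>
#include <vector>

// Each gathered dimension grows by its tile count; unselected dimensions keep
// their shard size.
std::vector<int64_t> AllGatheredShape(const std::vector<int64_t>& shard_dims,
                                      const std::vector<int64_t>& tiles_per_dim,
                                      const std::vector<int64_t>& selected_dims) {
  std::vector<int64_t> out = shard_dims;
  for (int64_t d : selected_dims) {
    out[d] = shard_dims[d] * tiles_per_dim[d];
  }
  return out;
}
// A [3,5] shard tiled 2x4 and gathered on both dims becomes [6,20]; gathering
// only on dim 1 would give [3,20].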
std::vector tiled_dims; for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { - if (sharding.tile_assignment().dim(i) > 1) { + if (sharding.tile_assignment().dim(i) > 1 && + absl::c_linear_search(selected_dims, i)) { tiled_dims.push_back(i); } } @@ -4747,7 +3175,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, std::vector xpose_permutation(result->shape().rank()); int64 split_dims_added = 0; for (int64 i = 0; i < xpose_permutation.size(); ++i) { - if (sharding.tile_assignment().dim(i - split_dims_added) == 1) { + if (sharding.tile_assignment().dim(i - split_dims_added) == 1 || + !absl::c_linear_search(selected_dims, i - split_dims_added)) { xpose_permutation[i] = i + tiled_dims.size() - split_dims_added; } else { xpose_permutation[i] = split_dims_added; diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index 52e4c9021d8..a612c16bdae 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -82,8 +83,10 @@ struct SPMDCollectiveOpsCreator { std::function create_partition_id; // Function used to create a cross-partition all-reduce HLO. - std::function + std::function>& partition_subgroups, + int64 channel_id)> create_cross_partition_all_reduce; // Function used to create a cross-partition collective-permute HLO. @@ -96,8 +99,8 @@ struct SPMDCollectiveOpsCreator { // Function used to create a cross-partition all-to-all HLO. std::function operands, - const std::vector& replica_groups, int64 channel_id, - absl::optional split_dimension)> + const std::vector>& partition_subgroups, + int64 channel_id, absl::optional split_dimension)> create_cross_partition_all_to_all; // Function used to create a cross-partition all-gather HLO. This is optional: @@ -169,10 +172,13 @@ class SpmdPartitioner : public HloModulePass { // The default uses a single all-gather even if there are multiple sharded // dimensions, and adds potential reshapes and transposes to achieve that. // If it returns false, the partitioner will fall back to all-reduce. - virtual HloInstruction* AllGatherShards(SpmdBuilder* b, - HloInstruction* operand, - const HloSharding& sharding, - int64 channel_id); + // `selected_dims` specifies the dimensions along which the all-gather happens + // in the tiled sharding, which allows potentially creating a subgroup + // all-gather. + virtual HloInstruction* AllGatherShards( + SpmdBuilder* b, HloInstruction* operand, const HloSharding& sharding, + int64 channel_id, absl::Span selected_dims, + const SPMDCollectiveOpsCreator& collectives_creator); protected: virtual std::unique_ptr CreateVisitor( @@ -215,7 +221,12 @@ class PartitionedHlo { std::tuple> window_reshard_cache; }; + // Use std::unordered_map for pointer stability. std::unordered_map per_hlo_cache; + // Caches for nested partitioning of grouped sharding. Each string key + // represents a unique way of grouping devices. + absl::flat_hash_map> + groupd_caches; }; struct PartitioningState { SpmdBuilder* b; @@ -270,21 +281,26 @@ class PartitionedHlo { const PartitioningState& state() const { return state_; } + // Helper function to replicate the data on all devices. 
Could only modify + // the reshard cache. + PartitionedHlo Replicate(); + + // Helper function to replicate the data for partitions along the given dims. + HloInstruction* ReplicatePartial(absl::Span dims); + private: // Same as Reshard except that it does not explicitly modify the reshard // cache, although it would indirectly modify by calling Replicate(). PartitionedHlo ReshardNoCache(const HloSharding& target); - // Helper function to replicate the data on all devices. Could only modify - // the reshard cache. - PartitionedHlo Replicate(); - // Helper function to broadcast data from a single device to all devices. PartitionedHlo Broadcast() const; // Helper function to reshard the tensor using AllToAll (instead of the // default of Replicate followed by Slice). - PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const; + PartitionedHlo ReshardWithAllToAll( + const HloSharding& target, + absl::Span> source_target_dims) const; // Helper function to reshard the tensor using CollectivePermute. PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; @@ -314,6 +330,22 @@ struct DotGeneralDimsMapping { std::vector rhs_non_contracting_dims; }; +struct ConvolutionDimsMapping { + // The dimension numbers for the operands and output corresponding to a + // logical dimension (e.g., batch, parallel, non-parallel). If an + // operand or the output doesn't have the logical dimension, it is set to + // -1. + struct DimsMapping { + int64 lhs; + int64 rhs; + int64 output; + // input mapped to index in input_spatial_dimensions(). + int64 spatial; + }; + std::vector parallel_spatial_dims; + std::vector non_parallel_spatial_dims; +}; + class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { public: SpmdPartitioningVisitor( @@ -354,9 +386,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { Status HandleIota(HloInstruction* hlo) override; Status HandlePartitionId(HloInstruction* hlo) override; - // Handles convolution where both LHS and RHS operands are tiled. - Status HandleConvolutionTiledLhsAndRhs(HloInstruction* hlo); - // Implementation of dot partitioning given DotGeneralDimsMapping. Status HandleDotHelper( HloInstruction* hlo, const DotGeneralDimsMapping& dims_mapping, @@ -415,6 +444,16 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { StatusOr DoPartition(HloComputation* computation, const HloSharding& root_sharding); + // Information about a loop created for windowed dot-general. Used when + // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor + // finishes traversing the graph. + struct WindowedDotGeneralLoop { + HloInstruction* while_loop; + int64 windowed_operand; + bool windowed_in_contracting_dims; + bool windowed_in_batch_dims; + }; + private: Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; @@ -443,15 +482,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault { // partitioned instruction. ConstHloInstructionMap partitioned_instructions_; - // Information about a loop created for windowed dot-general. Used when - // DoCodeMotionForWindowedDotGeneralLoops() executes after the visitor - // finishes traversing the graph. 
- struct WindowedDotGeneralLoop { - HloInstruction* while_loop; - int64 windowed_operand; - bool windowed_in_contracting_dims; - bool windowed_in_batch_dims; - }; std::vector windowed_dot_general_loops_; HloInstruction* visiting_hlo_; diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 1f0b1d06c1f..3ffe2954d61 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -380,6 +380,43 @@ ENTRY entry { op::GetTupleElement(second_infeed)))); } +TEST_F(SpmdPartitioningTest, MixedTupleInfeed) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + token0 = token[] after-all(), sharding={maximal device=0} + infeed = ((f32[9,2]{1,0}, f32[2]{0}), token[]) infeed(token0), + sharding={{maximal device=0}, {maximal device=1}, {maximal device=0}} + ROOT infeed.data = (f32[9,2]{1,0}, f32[2]{0}) get-tuple-element(infeed), + index=0, sharding={{maximal device=0}, {maximal device=1}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/2)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("(f32[9,2], f32[2])"), + op::GetTupleElement(op::Conditional( + op::Convert(op::PartitionId()), op::AfterAll(), + op::AfterAll())))); + auto first_infeed = AllOf(op::Shape("((f32[9,2], ()), token[])"), + op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[0]->root_instruction(), + AllOf(op::Shape("((f32[9,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::GetTupleElement( + op::GetTupleElement(first_infeed)), + op::Broadcast(op::Constant())), + op::GetTupleElement(first_infeed)))); + auto second_infeed = + AllOf(op::Shape("(((), f32[2]), token[])"), op::Infeed(op::Parameter())); + EXPECT_THAT(root->operand(0)->called_computations()[1]->root_instruction(), + AllOf(op::Shape("((f32[9,2], f32[2]), token[])"), + op::Tuple(op::Tuple(op::Broadcast(op::Constant()), + op::GetTupleElement(op::GetTupleElement( + second_infeed))), + op::GetTupleElement(second_infeed)))); +} + TEST_F(SpmdPartitioningTest, TiledToReplicatedReduce) { const char* const hlo_string = R"( HloModule module @@ -527,6 +564,80 @@ ENTRY entry { op::Constant()))))); } +TEST_F(SpmdPartitioningTest, + BroadcastBothOldAndNewDimsShardedPartiallySharded) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + param = f32[4,3] parameter(0), + sharding={devices=[1,2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} + ROOT broadcast = f32[4,4,3] broadcast(param), dimensions={1,2}, + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::Shape("f32[2,4,2]"), + op::Broadcast(AllOf(op::Shape("f32[4,2]"), op::Parameter(0))))); +} + +TEST_F(SpmdPartitioningTest, + ConvWithParallelDimAndNonParallelSpatialDimPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[32,12,12,24,32] parameter(0) + %lhs.copy = f32[32,12,12,24,32] copy(%lhs), + sharding={devices=[2,2,1,1,1]0,1,2,3} + %rhs = f32[32,6,6,16,32] parameter(1) + %rhs.copy = f32[32,6,6,16,32] copy(%rhs), + sharding={devices=[2,2,1,1,1]0,1,2,3} + ROOT 
%conv = f32[32,7,7,24,16] convolution(%lhs.copy, %rhs.copy), + dim_labels=012bf_012oi->012bf, + window={size=32x6x6 stride=31x1x1 lhs_dilate=32x1x1}, + sharding={devices=[2,2,1,1,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[16,6,12,24,32]")); + auto rhs = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Reshape(), op::Constant(), + op::Constant(), op::Constant())), + op::Shape("f32[16,3,6,16,32]")); + auto resharded_rhs = + AllOf(op::Shape("f32[16,6,6,16,32]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant(), op::Constant()))); + + auto left_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[16,2,12,24,32]")); + auto right_halo = AllOf(op::CollectivePermute(op::Slice(lhs)), + op::Shape("f32[16,3,12,24,32]")); + EXPECT_THAT( + root, + AllOf(op::Convolution( + op::Select(op::Compare(), + op::DynamicSlice( + op::Concatenate(left_halo, lhs, right_halo), + op::Constant(), op::Add(), op::Constant(), + op::Constant(), op::Constant()), + op::Broadcast()), + resharded_rhs), + op::Shape("f32[16,4,7,24,16]"))); +} + TEST_F(SpmdPartitioningTest, BroadcastPropagateTiledSharding) { const char* const hlo_string = R"( HloModule module @@ -1399,6 +1510,50 @@ ENTRY entry { op::Shape("f32[1,1,512,64]"))); } +TEST_F(SpmdPartitioningTest, + ConvolutionLhsTiledRhsTiled_UnevenDilatedRHSPartitioned) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[8,28,28,8] parameter(0) + %lhs.copy = f32[8,28,28,8] copy(%lhs), sharding={devices=[1,4,1,1]0,1,2,3} + %rhs = f32[8,14,14,64] parameter(1) + %rhs.copy = f32[8,14,14,64] copy(%rhs), sharding={devices=[1,4,1,1]0,1,2,3} + ROOT %conv = f32[1,1,8,64] convolution(%lhs.copy, %rhs.copy), + window={size=14x14 pad=0_-1x0_-1 rhs_dilate=2x2}, + dim_labels=f01b_i01o->01bf, sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto lhs = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[8,7,28,8]")); + auto rhs = AllOf(op::Pad(op::Parameter(), op::Constant()), + op::Shape("f32[8,16,14,64]")); + auto selected_rhs = AllOf( + op::Select(op::Compare(), + op::Copy(op::DynamicSlice(rhs, op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Broadcast()), + op::Shape("f32[8,4,14,64]")); + auto right_halo = + AllOf(op::CollectivePermute(op::Slice(lhs)), op::Shape("f32[8,2,28,8]")); + auto selected_lhs = + AllOf(op::DynamicSlice( + op::Pad(op::Concatenate(lhs, right_halo), op::Constant()), + op::Constant(), op::Reshape(), op::Constant(), op::Constant()), + op::Shape("f32[8,7,28,8]")); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Convolution(selected_lhs, selected_rhs)), + op::Shape("f32[1,1,8,64]"))); +} + TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWithPadding) { const char* const hlo_string = R"( HloModule module @@ -2218,7 +2373,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, 
/*num_devices=*/2)); - std::cout << module->ToString(); + VLOG(1) << module->ToString(); auto sort = FindInstruction(module.get(), "sort"); EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664); EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664); @@ -2294,7 +2449,7 @@ ENTRY entry TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); - std::cout << module->ToString(); + VLOG(1) << module->ToString(); auto sort = FindInstruction(module.get(), "sort"); EXPECT_EQ(sort->operand(0)->shape().dimensions(1), 209664); EXPECT_EQ(sort->operand(1)->shape().dimensions(1), 209664); @@ -2612,6 +2767,35 @@ ENTRY entry { AllOf(op::Reduce(param0, op::Constant()), op::Shape("f32[64]"))); } +TEST_F(SpmdPartitioningTest, PartialTiledToPartialTiledReduce) { + const char* const hlo_string = R"( +HloModule module + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add = f32[] add(a, b) +} + +ENTRY entry { + %param0 = f32[4,4] parameter(0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %constant.1 = f32[] constant(0), sharding={replicated} + ROOT %reduce = f32[4] reduce(%param0, %constant.1), dimensions={0}, + to_apply=%sum, + sharding={devices=[2,4]0,1,4,5,2,3,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::AllReduce(op::Reduce(op::Parameter(0), op::Constant())), + op::Shape("f32[2]"))); +} + TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce) { const char* const hlo_string = R"( HloModule module @@ -3576,6 +3760,25 @@ ENTRY entry { op::Shape("f32[3,5]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughGather) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3} + ROOT %gather = f32[8,4,4] gather(%input, %indices), offset_dims={0}, + collapsed_slice_dims={0,1}, start_index_map={0,1}, index_vector_dim=1, + slice_sizes={1,1,8}, sharding={devices=[1,2,2]0,1,2,3} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Gather(op::Parameter(0), op::Parameter(1)), + op::Shape("f32[8,2,2]"))); +} + TEST_F(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { const char* const hlo_string = R"( HloModule module @@ -3635,6 +3838,74 @@ ENTRY entry { op::Shape("f32[2,5]"))); } +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter) { + const char* const hlo_string = R"( +HloModule module + +add (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT sum = f32[] add(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3} + %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=add, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + 
HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::PartitionId())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + +TEST_F(SpmdPartitioningTest, IndexPassthroughScatter_Min) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[2,9,8] parameter(0), sharding={replicated} + %indices = s32[4,2,4] parameter(1), sharding={devices=[2,1,2]0,1,2,3} + %updates = f32[4,4,8] parameter(2), sharding={devices=[2,2,1]0,1,2,3} + ROOT %scatter = f32[2,9,8] scatter(%input, %indices, %updates), + to_apply=min, + update_window_dims={2}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1, sharding={replicated} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, + AllOf(op::AllReduce(op::Scatter( + op::Select(op::Broadcast(op::Convert(op::PartitionId())), + op::Broadcast(op::Constant()), op::Parameter(0)), + op::Parameter(1), op::Parameter(2))), + op::Shape("f32[2,9,8]"))); +} + TEST_F(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { const char* const hlo_string = R"( HloModule module @@ -3766,6 +4037,364 @@ ENTRY entry { op::Parameter(0)))); } +TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8,8,8] parameter(0), + sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7} + ROOT %copy = f32[8,8,8,8] copy(%param0), + sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto reshape = + AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0))); + auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape)); + auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all)); + EXPECT_THAT(root, + op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]")))); + EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(), + 4); +} + +TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0), + sharding={devices=[2,4]0,1,2,3,4,5,6,7} + ROOT %copy = f32[8,8] copy(%param0), + sharding={devices=[4,2]0,1,4,5,2,3,6,7} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto all_to_all = op::AllToAll( + AllOf(op::Shape("f32[2,2,2]"), op::Reshape(op::Parameter(0)))); + auto reshape = + AllOf(op::Shape("f32[2,4]"), op::Reshape(op::Transpose(all_to_all))); + EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape))); +} + +TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard3) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8,8] parameter(0), + sharding={devices=[2,4,1]0,1,2,3,4,5,6,7} + ROOT %copy = f32[8,8,8] 
copy(%param0), + sharding={devices=[1,2,4]0,1,4,5,2,3,6,7} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto all_to_all = op::AllToAll( + AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(op::Parameter(0)))); + auto reshape = + AllOf(op::Shape("f32[4,8,2]"), op::Reshape(op::Transpose(all_to_all))); + auto all_to_all2 = + op::AllToAll(AllOf(op::Shape("f32[4,2,4,2]"), op::Reshape(reshape))); + auto reshape2 = + AllOf(op::Shape("f32[8,4,2]"), op::Reshape(op::Transpose(all_to_all2))); + EXPECT_THAT(root, op::Copy(op::CollectivePermute(reshape2))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting0) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3} + %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3} + ROOT %dot = f32[48,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,6]"), op::Parameter(0)); + auto partial_replicated_lhs = + AllOf(op::Shape("f32[24,12]"), + op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _))); + auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice( + _, op::CollectivePermute(rhs), _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs), + op::Shape("f32[24,16]"))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[48,100] parameter(0), sharding={devices=[2,2]0,1,2,3} + %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3} + ROOT %dot = f32[48,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[32,50]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT( + root, AllOf(op::Shape("f32[24,16]"), + op::DynamicSlice( + op::AllReduce(AllOf(op::Dot(lhs, partial_replicated_rhs), + op::Shape("f32[24,32]"))), + _, _))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedNonContractingAndContracting2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[48,100] parameter(0), sharding={replicated} + %rhs = f32[32,100] parameter(1), sharding={devices=[2,2]0,1,2,3} + ROOT %dot = f32[48,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = 
AllOf(op::Shape("f32[48,100]"), op::Parameter(0)); + auto lhs_slice = AllOf(op::Shape("f32[24,100]"), op::DynamicSlice(lhs, _, _)); + auto rhs = AllOf(op::Shape("f32[16,50]"), op::Parameter(1)); + auto partial_replicated_rhs = AllOf( + op::Shape("f32[16,100]"), op::AllReduce(op::DynamicUpdateSlice( + _, op::CollectivePermute(rhs), _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[24,16]"), + op::Dot(lhs_slice, partial_replicated_rhs))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndNonContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,12,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[2,32,100]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"), + op::Dot(lhs, partial_replicated_rhs))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={devices=[1,2,2]0,1,2,3} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[4,16,50]"), op::Parameter(1)); + auto resharded_rhs = + AllOf(op::Shape("f32[2,32,50]"), + op::Reshape(op::Transpose(op::AllToAll(op::Reshape(rhs))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"), + op::DynamicSlice( + AllOf(op::Shape("f32[2,24,32]"), + op::AllReduce(op::Dot(lhs, resharded_rhs))), + _, _, _))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndContracting2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={replicated} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0)); + auto resharded_lhs = + AllOf(op::Shape("f32[2,12,100]"), + op::Reshape(op::Transpose(op::AllToAll(op::Reshape(lhs))))); + auto rhs = AllOf(op::Shape("f32[4,32,100]"), op::Parameter(1)); + auto rhs_slice = + AllOf(op::Shape("f32[2,32,100]"), op::DynamicSlice(rhs, _, _, _)); + auto root 
= module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,12,32]"), + op::Dot(resharded_lhs, rhs_slice))); +} + +TEST_F(SpmdPartitioningTest, + Dot2DPartitionedBatchNonContractingAndContracting) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,24,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[4,32,100] parameter(1), sharding={devices=[2,2,1]0,1,2,3} + ROOT %dot = f32[4,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0}, rhs_batch_dims={0}, + lhs_contracting_dims={2}, rhs_contracting_dims={2}, + sharding={devices=[2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,24,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,16,100]"), op::Parameter(1)); + auto partial_replicated_lhs = + AllOf(op::Shape("f32[2,24,100]"), + op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,24,16]"), + op::Dot(partial_replicated_lhs, rhs))); +} + +TEST_F(SpmdPartitioningTest, Dot2DPartitionedBatchAndReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[4,8,24,100] parameter(0), sharding={devices=[2,1,2,1]0,1,2,3} + %rhs = f32[4,8,32,100] parameter(1), sharding={devices=[2,1,2,1]0,1,2,3} + ROOT %dot = f32[4,8,24,32] dot(%lhs, %rhs), + lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={3}, + sharding={devices=[1,2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[2,8,12,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[2,8,16,100]"), op::Parameter(1)); + auto partial_replicated_rhs = + AllOf(op::Shape("f32[2,8,32,100]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _, _, _))); + auto dot = + AllOf(op::Shape("f32[2,8,12,32]"), op::Dot(lhs, partial_replicated_rhs)); + auto reshape = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Reshape(dot)); + auto all_to_all = AllOf(op::Shape("f32[2,2,4,12,32]"), op::AllToAll(reshape)); + auto xpose = AllOf(op::Shape("f32[2,2,4,12,32]"), op::Transpose(all_to_all)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[4,4,12,32]"), op::Reshape(xpose))); +} + +TEST_F(SpmdPartitioningTest, + ElementwiseTest_PartialReplicateToTiledHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + constant = f32[6,3]{1,0} + constant({{1,3,7},{5,1,4},{1,2,8},{2,3,7},{5,2,4},{2,2,8}}), + sharding={replicated} + constant.1 = f32[6,3]{1,0} + constant({{2,7,2},{2,9,2},{2,6,2},{3,7,2},{2,9,3},{2,3,2}}), + sharding={replicated} + multiply = f32[6,3]{1,0} multiply(constant, constant.1), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT add = f32[6,3]{1,0} add(multiply, constant.1), + sharding={devices=[4,1]0,1,2,3} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto partial_replicate_lhs = + AllOf(op::Shape("f32[3,3]"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto partial_replicate_rhs = + AllOf(op::Shape("f32[3,3]"), + op::DynamicSlice(op::Constant(), op::Reshape(), op::Constant())); + auto multiply = + 
AllOf(op::Shape("f32[3,3]"), + op::Multiply(partial_replicate_lhs, partial_replicate_rhs)); + auto right_halo = + AllOf(op::Shape("f32[1,3]"), op::CollectivePermute(op::Slice(multiply))); + auto add_lhs = AllOf( + op::Shape("f32[2,3]"), + op::DynamicSlice( + op::DynamicSlice( + op::Pad(op::Concatenate(multiply, right_halo), op::Constant()), + op::Reshape(), op::Constant()), + op::Reshape(), op::Constant())); + auto add_rhs = AllOf(op::Shape("f32[2,3]"), + op::DynamicSlice(op::Pad(op::Constant(), op::Constant()), + op::Reshape(), op::Constant())); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index 3354a9c3233..3443c6e013d 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -16,7 +16,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h" #include +#include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -25,10 +30,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { @@ -41,6 +49,23 @@ bool HasReplicatedSharding(const HloSharding& sharding) { return sharding.IsReplicated(); } +HloInstruction* CreateConstant(const Shape& shape, Literal value, + SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back(CreateConstant( + ShapeUtil::GetTupleElementShape(shape, i), value.Clone(), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + CHECK( + ShapeUtil::IsScalarWithElementType(value.shape(), shape.element_type())); + auto c = b->AddInstruction(HloInstruction::CreateConstant(std::move(value))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, c, {})); +} + HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { if (shape.IsTuple()) { std::vector elements; @@ -59,6 +84,24 @@ HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b) { return b->AddInstruction(HloInstruction::CreateBroadcast(shape, zero, {})); } +HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b) { + if (shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + elements.push_back( + CreateOne(ShapeUtil::GetTupleElementShape(shape, i), b)); + } + return b->AddInstruction(HloInstruction::CreateTuple(elements)); + } + + if 
(shape.IsToken()) { + return b->AddInstruction(HloInstruction::CreateToken()); + } + auto one = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(shape.element_type()))); + return b->AddInstruction(HloInstruction::CreateBroadcast(shape, one, {})); +} + HloComputation* MakeBinaryAdd(PrimitiveType type, HloModule* module) { HloComputation::Builder sum_b("add"); auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( @@ -128,6 +171,16 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, return ShapeUtil::MakeTupleShape(subshapes); } + if (sharding.IsReplicated()) { + return shape; + } + if (sharding.IsTileMaximal()) { + if (partition_id == *sharding.UniqueDevice()) { + return shape; + } + return ShapeUtil::MakeTupleShape({}); + } + auto partition_shape = shape; std::vector tile_offset = sharding.TileOffsetForDevice(shape, partition_id); @@ -143,10 +196,10 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, return partition_shape; } -std::vector MakePartitionOffsets(const Shape& shape, - const HloSharding& sharding, - HloInstruction* partition_id, - SpmdBuilder* b) { +std::vector MakePartitionOffsets( + const Shape& shape, const HloSharding& sharding, + HloInstruction* partition_id, SpmdBuilder* b, + absl::Span dims) { CHECK(!shape.IsTuple()); Array2D offset_array( @@ -158,7 +211,8 @@ std::vector MakePartitionOffsets(const Shape& shape, LiteralUtil::CreateR2FromArray2D(offset_array))); std::vector offsets; for (int64 i = 0; i < shape.rank(); ++i) { - if (sharding.tile_assignment().dim(i) == 1) { + if (sharding.tile_assignment().dim(i) == 1 || + (!dims.empty() && !absl::c_linear_search(dims, i))) { offsets.push_back(b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); } else { @@ -177,8 +231,11 @@ std::vector MakePartitionOffsets(const Shape& shape, std::vector MakeTiledPartitionOrdinals( const HloSharding& sharding, HloInstruction* partition_id, SpmdBuilder* b) { CHECK(!sharding.IsTileMaximal()); - auto table_shape = - ShapeUtil::MakeShape(S32, sharding.tile_assignment().dimensions()); + auto dimensions = sharding.tile_assignment().dimensions(); + if (sharding.ReplicateOnLastTileDim()) { + dimensions.pop_back(); + } + auto table_shape = ShapeUtil::MakeShape(S32, dimensions); return MakePartitionOffsets(table_shape, sharding, partition_id, b); } @@ -235,6 +292,195 @@ HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( return PadToShape(hlo, padded_base_shape, b); } +// TODO(wangtao): generize this function when target is partial replicate. +absl::optional PartialReplicateToTileCompatibleSharding( + const HloSharding& partial_sharding, + const std::vector& target_tile_dims) { + if (!partial_sharding.ReplicateOnLastTileDim()) { + return absl::nullopt; + } + int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1; + if (target_tile_dims.size() < rank) { + return absl::nullopt; + } + // A dimension is expanded when target_tile_size > partial_tile_size and + // target_tile_size % partial_tile_size == 0. + // expand_tile_dims_positions is the index of the expand_dim. + std::vector expand_tile_dims_indices(rank, -1); + // expand_tile_size = target_tile_size / partial_tile_size. 
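The partial-replicate-to-tile conversion here first checks whether the target tiling is reachable: each target tile count must be a multiple of the source count, and the per-dimension expansion factors must multiply to the trailing replication factor. A simplified standalone version of that feasibility check (using std::optional rather than absl::optional, and plain tile-count vectors rather than HloSharding):

#include <cstdint>
#include <optional>
#include <vector>

std::optional<std::vector<int64_t>> ExpandFactors(
    const std::vector<int64_t>& partial_tiles,  // without the replication dim
    int64_t replication, const std::vector<int64_t>& target_tiles) {
  std::vector<int64_t> factors;
  int64_t product = 1;
  for (size_t d = 0; d < partial_tiles.size(); ++d) {
    if (target_tiles[d] % partial_tiles[d] != 0) return std::nullopt;
    int64_t f = target_tiles[d] / partial_tiles[d];
    if (f > 1) {
      factors.push_back(f);  // expand_tile_size for this dimension
      product *= f;
    }
  }
  if (product != replication) return std::nullopt;
  return factors;
}
// Example: partial tiles {1,2} with replication 4 and target tiles {4,2}
// gives the single expansion factor {4}.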
+ std::vector expand_tile_sizes; + int num_expand_dims = 0; + for (int64 dim = 0; dim < rank; dim++) { + int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim); + int64 target_tile_size = target_tile_dims[dim]; + if (target_tile_size % partial_tile_size != 0 || + target_tile_size < partial_tile_size) { + return absl::nullopt; + } + + if (target_tile_size > partial_tile_size) { + expand_tile_dims_indices[dim] = num_expand_dims++; + expand_tile_sizes.emplace_back(target_tile_size / partial_tile_size); + } + } + + // Reshape the partial replicate tile_dimensions. + auto reshape_dimensions = partial_sharding.tile_assignment().dimensions(); + int64 num_replication = reshape_dimensions.back(); + if (num_replication != Product(expand_tile_sizes)) { + return absl::nullopt; + } + reshape_dimensions.pop_back(); + reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(), + expand_tile_sizes.end()); + auto reshape_tile_assignment = partial_sharding.tile_assignment(); + + // Transpose. + std::vector perm; + perm.reserve(rank); + for (int64 dim = 0; dim < rank; dim++) { + perm.emplace_back(dim); + if (expand_tile_dims_indices[dim] > -1) { + perm.emplace_back(expand_tile_dims_indices[dim] + rank); + } + } + auto transpose_sharding = hlo_sharding_util::TransposeSharding( + HloSharding::Tile(reshape_tile_assignment), perm); + + // Reshape to target shape + auto transpose_tile_assignment = transpose_sharding.tile_assignment(); + transpose_tile_assignment.Reshape(target_tile_dims); + + return HloSharding::Tile(transpose_tile_assignment); +} + +absl::optional PadFromPartialReplicateShape( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& expand_tile_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b) { + auto padded_src_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, src_sharding); + auto padded_dst_shape = + GetPaddedShapeForUnevenPartitioning(base_shape, dst_sharding); + if (ShapeUtil::Compatible(padded_dst_shape, hlo->shape())) { + return hlo; + } + + auto partition_ordinals = + MakeTiledPartitionOrdinals(src_sharding, partition_id, b); + + HloInstruction* result = hlo; + auto zero = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(hlo->shape().element_type()))); + std::vector expand_dims_without_halo_exchange; + // Pad the dimensions needs halo exchange and record the padded dims that + // won't need halo exchange. + for (auto dim : expand_tile_dims) { + int64 src_shard_count = src_sharding.tile_assignment().dim(dim); + int64 src_per_shard_size = + padded_src_shape.dimensions(dim) / src_shard_count; + // Calculate per shard size using the sharding to compare if dst_sharding + // needs more padding at the end. + int64 dst_per_shard_size = + padded_dst_shape.dimensions(dim) / src_shard_count; + + // If dst_sharding doesn't need more padding at the end. + if (src_per_shard_size >= dst_per_shard_size) { + continue; + } + // If src sharding at this dimension is not partitoned, simply pad to + // the desired shape. + if (src_shard_count == 1) { + expand_dims_without_halo_exchange.emplace_back(dim); + continue; + } + + // If dst_padding needs more padding at the end, need to re-distribute the + // data between each shard using collective permute. + // For example, if dimension size is 6 and shard 2 ways in the src but + // needs to shard 4 ways in the dst. 
4 ways needs padding 2 0s at the end + // and has 2 elements at each shard, while 2 way sharding has 3 elements + // in each shard, re-distribution is needed. + // + // 1. Calculate left_halo size. + // left-halo size is 0 + OffsetCalculation left_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation(0, 0, 1)); + + // 2. Calculate right_halo size. + // right-halo size is D * (i + 1) - S * (i + 1) = (D - S) * i + (D - S) + OffsetCalculation right_halo_size_function = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + dst_per_shard_size - src_per_shard_size, + dst_per_shard_size - src_per_shard_size, 1)); + + auto concat = result; + // 3. Halo exchange. + auto halo_exchange_result = ExchangeHalo( + result, left_halo_size_function, right_halo_size_function, dim, + src_sharding, collective_ops_creator, next_channel_id, b); + + if (halo_exchange_result.has_value()) { + concat = halo_exchange_result.value(); + } else { + return absl::nullopt; + } + + // 4. Pad. + std::vector zero_padding(concat->shape().rank()); + PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding); + pad_config.mutable_dimensions(dim)->set_edge_padding_low(0); + int64 max_right_halo_size = + right_halo_size_function.MaxInRange(0, src_shard_count - 1); + pad_config.mutable_dimensions(dim)->set_edge_padding_high(std::max( + 0LL, padded_dst_shape.dimensions(dim) - + padded_src_shape.dimensions(dim) - max_right_halo_size)); + auto padded_concat_shape = ShapeInference::InferPadShape( + concat->shape(), zero->shape(), pad_config) + .ValueOrDie(); + concat = b->AddInstruction(HloInstruction::CreatePad( + padded_concat_shape, concat, zero, pad_config)); + + // 5. Slice the valid result. + // Slice offset is (D-S) * i + auto zero_s32 = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(S32))); + OffsetCalculation start_offset_on_padded_concat_calculation = + OffsetCalculation(MultiplyAddDivideOffsetCalculation( + dst_per_shard_size - src_per_shard_size, 0, 1)); + auto slice_shape = concat->shape(); + slice_shape.set_dimensions(dim, dst_per_shard_size); + std::vector slice_offsets(concat->shape().rank(), + zero_s32); + slice_offsets[dim] = start_offset_on_padded_concat_calculation.Calculate( + partition_ordinals[dim], b); + result = b->AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, concat, slice_offsets, slice_shape.dimensions())); + } + + // Pad other dimensions that won't need halo exchange with a single pad. + if (!expand_dims_without_halo_exchange.empty()) { + std::vector zero_padding(result->shape().rank()); + PaddingConfig pad_config = window_util::MakeSymmetricPadding(zero_padding); + + auto padded_shape = result->shape(); + for (auto dim : expand_dims_without_halo_exchange) { + pad_config.mutable_dimensions(dim)->set_edge_padding_low(0); + pad_config.mutable_dimensions(dim)->set_edge_padding_high( + padded_dst_shape.dimensions(dim) - padded_src_shape.dimensions(dim)); + padded_shape.set_dimensions(dim, result->shape().dimensions(dim) + + padded_dst_shape.dimensions(dim) - + padded_src_shape.dimensions(dim)); + } + result = b->AddInstruction( + HloInstruction::CreatePad(padded_shape, result, zero, pad_config)); + } + + return result; +} + absl::optional UniqueTiledDim(const HloSharding& sharding) { if (sharding.IsTileMaximal()) { return absl::nullopt; @@ -877,5 +1123,461 @@ HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder, output_shape, hlo, start_indices, limit_indices, strides)); } +// Check if a dimension is sharded. 
+int64 ShardCountAtDim(const HloSharding& sharding, int64 dim) { + if (sharding.IsTileMaximal()) { + return 1; + } + return sharding.tile_assignment().dim(dim); +} + +absl::optional>> +GetReshardAllToAllSourceTargetDims(const HloSharding& source, + const HloSharding& target) { + if (source.IsTileMaximal() || target.IsTileMaximal() || + source.tile_assignment().num_dimensions() != + target.tile_assignment().num_dimensions() || + source.NumTiles() != target.NumTiles()) { + return absl::nullopt; + } + // Record partition count to index for indices that have different partition + // counts on source and target. + std::map> source_size_to_dim; + std::map> target_size_to_dim; + for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) { + if (source.tile_assignment().dim(i) == target.tile_assignment().dim(i)) { + continue; + } + source_size_to_dim[source.tile_assignment().dim(i)].push_back(i); + target_size_to_dim[target.tile_assignment().dim(i)].push_back(i); + } + // In order to shard via AllToAll, source_size_to_dim and target_size_to_dim + // must have the same distribution. + if (source_size_to_dim.empty() || + source_size_to_dim.size() != target_size_to_dim.size()) { + return absl::nullopt; + } + for (const auto& entry : source_size_to_dim) { + auto target_it = target_size_to_dim.find(entry.first); + if (target_it == target_size_to_dim.end() || + target_it->second.size() != entry.second.size()) { + return absl::nullopt; + } + } + std::vector> result; + auto remove_entry = [](int64 size, int64 dim, + std::map>& size_to_dim) { + size_to_dim[size].erase( + std::remove_if(size_to_dim[size].begin(), size_to_dim[size].end(), + [dim](int64 a) { return a == dim; }), + size_to_dim[size].end()); + if (size_to_dim[size].empty()) { + size_to_dim.erase(size); + } + }; + // Find one pair of dimensions to swap at a time. + while (!source_size_to_dim.empty()) { + int64 source_size = source_size_to_dim.begin()->first; + int64 i = source_size_to_dim.begin()->second.back(); + int64 target_i_size = target.tile_assignment().dim(i); + if (target_i_size == source_size) { + remove_entry(source_size, i, source_size_to_dim); + remove_entry(source_size, i, target_size_to_dim); + continue; + } + auto j_it = source_size_to_dim[target_i_size].begin(); + int64 j = *j_it; + if (source_size == 1) { + // If possible, find a j where the target partition count is not one, so + // that when we swap, the resulting size-1 dimension will still be useful + // to other dimensions. + while (target.tile_assignment().dim(j) == 1) { + if (++j_it == source_size_to_dim[target_i_size].end()) { + break; + } + j = *j_it; + } + } else if (target_i_size % source_size == 0) { + // If possible, find a j where the target partition count is source_size, + // so that we can do a single swap. 
+ while (target.tile_assignment().dim(j) != source_size) { + if (++j_it == source_size_to_dim[target_i_size].end()) { + break; + } + j = *j_it; + } + } else { + return absl::nullopt; + } + result.emplace_back(j, i); + remove_entry(target_i_size, i, target_size_to_dim); + source_size_to_dim.begin()->second.back() = j; + remove_entry(target_i_size, j, source_size_to_dim); + } + return result; +} + +bool CanReshardWithCollectivePermute(const HloSharding& source, + const HloSharding& target) { + return !source.IsTileMaximal() && !target.IsTileMaximal() && + source.tile_assignment().dimensions() == + target.tile_assignment().dimensions() && + source.ReplicateOnLastTileDim() == target.ReplicateOnLastTileDim() && + source.tile_assignment() != target.tile_assignment(); +} + +GroupedSharding GroupShardingOnDims(const HloSharding& sharding, + absl::Span group_dims) { + CHECK(!sharding.IsTileMaximal()); + std::vector grouped_tiling_dims = + sharding.tile_assignment().dimensions(); + std::vector group_dim_sizes(group_dims.size()); + for (int64 i = 0; i < group_dims.size(); ++i) { + group_dim_sizes[i] = grouped_tiling_dims[group_dims[i]]; + grouped_tiling_dims[group_dims[i]] = 1; + } + std::vector> device_groups(Product(group_dim_sizes)); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + int64 group_id = 0; + for (int64 dim : group_dims) { + group_id *= sharding.tile_assignment().dim(dim); + group_id += indices[dim]; + } + device_groups[group_id].push_back(device); + }); + Array grouped_tiling(grouped_tiling_dims); + grouped_tiling.FillIota(0); + return GroupedSharding( + std::move(device_groups), + std::vector(group_dims.begin(), group_dims.end()), + std::move(group_dim_sizes), sharding.tile_assignment().num_dimensions(), + HloSharding::Tile(grouped_tiling)); +} + +HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { + CHECK(!grouped_sharding.sharding.IsTileMaximal()); + std::vector tiling_dims = + grouped_sharding.sharding.tile_assignment().dimensions(); + for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { + tiling_dims[grouped_sharding.group_dims[i]] = + grouped_sharding.group_dim_sizes[i]; + } + Array tiling(tiling_dims); + grouped_sharding.sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + std::vector ungrouped_inds(indices.begin(), indices.end()); + for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { + int64 remaining_group_index = g; + for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { + ungrouped_inds[grouped_sharding.group_dims[i]] = + remaining_group_index % grouped_sharding.group_dim_sizes[i]; + remaining_group_index /= grouped_sharding.group_dim_sizes[i]; + } + tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; + } + }); + return HloSharding::Tile(tiling); +} + +GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding, + const GroupedSharding& reference, + bool ignore_group_order) { + // Returns src -> dst index mapping. 
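+  // For example, src = {2, 3, 5} and dst = {3, 5, 2} yields {2, 0, 1}.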
+ auto get_permutation = [](absl::Span src, + absl::Span dst) { + CHECK_EQ(src.size(), dst.size()); + absl::flat_hash_map dst_reverse_map; + for (int64 i = 0; i < dst.size(); ++i) { + dst_reverse_map[dst[i]] = i; + } + std::vector permutation(src.size()); + for (int64 i = 0; i < src.size(); ++i) { + auto it = dst_reverse_map.find(src[i]); + CHECK(it != dst_reverse_map.end()); + permutation[i] = it->second; + } + return permutation; + }; + CHECK_EQ(grouped_sharding.device_groups.size(), + reference.device_groups.size()); + absl::flat_hash_map device_to_ref_group; + for (int64 g = 0; g < reference.device_groups.size(); ++g) { + for (int64 device : reference.device_groups[g]) { + device_to_ref_group[device] = g; + } + } + auto unique_ref_dev_group = [&](absl::Span devices) -> int64 { + int64 ref_g = -1; + for (int64 device : devices) { + if (ref_g == -1) { + ref_g = device_to_ref_group[device]; + } else if (ref_g != device_to_ref_group[device]) { + return -1; + } + } + return ref_g; + }; + bool matching_groups = true; + std::vector original_src_to_ref_permutation; + for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { + int64 ref_g = unique_ref_dev_group(grouped_sharding.device_groups[g]); + if (ref_g < 0 || (!ignore_group_order && g != ref_g)) { + matching_groups = false; + break; + } + if (g == 0) { + original_src_to_ref_permutation = get_permutation( + grouped_sharding.device_groups[g], reference.device_groups[ref_g]); + } + } + if (matching_groups) { + auto tiles = grouped_sharding.sharding.tile_assignment(); + tiles.Each([&](absl::Span indices, int64* device) { + *device = original_src_to_ref_permutation[*device]; + }); + grouped_sharding.sharding = HloSharding::Tile(tiles); + } + grouped_sharding.device_groups = std::move(reference.device_groups); + return grouped_sharding; +} + +Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding, + const Shape& original_base_shape) { + auto result = original_base_shape; + for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { + int64 dim = grouped_sharding.group_dims[i]; + int64 groups = grouped_sharding.group_dim_sizes[i]; + result.set_dimensions(dim, result.dimensions(dim) / groups); + } + return result; +} + +namespace { + +HloInstruction* GetInGroupPartitionId( + HloInstruction* partition_id, + const std::vector>& device_groups, SpmdBuilder* b) { + int64 total_devices = device_groups.size() * device_groups[0].size(); + std::vector in_group_ids(total_devices); + for (uint32 i = 0; i < device_groups.size(); ++i) { + for (uint32 j = 0; j < device_groups[i].size(); ++j) { + in_group_ids[device_groups[i][j]] = j; + } + } + auto id_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(in_group_ids))); + return b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeScalarShape(U32), + b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(U32, {1}), id_table, {partition_id}, {1})))); +} + +SPMDCollectiveOpsCreator GetPerGroupCollectiveOpsCreator( + const SPMDCollectiveOpsCreator& creator, + const std::vector>& device_groups) { + SPMDCollectiveOpsCreator result; + result.create_partition_id = [creator, device_groups](SpmdBuilder* b) { + return GetInGroupPartitionId(creator.create_partition_id(b), device_groups, + b); + }; + auto expand_partition_groups = + [device_groups]( + const std::vector>& partition_subgroups) { + if (partition_subgroups.empty()) { + return device_groups; + } + std::vector> result(partition_subgroups.size() * + device_groups.size()); 
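+          // result[g * partition_subgroups.size() + i] holds subgroup i with
+          // each element remapped into the g-th device group.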
+ for (int64 g = 0; g < device_groups.size(); ++g) { + for (int64 i = 0; i < partition_subgroups.size(); ++i) { + result[g * partition_subgroups.size() + i].resize( + partition_subgroups[i].size()); + for (int64 j = 0; j < partition_subgroups[i].size(); ++j) { + result[g * partition_subgroups.size() + i][j] = + device_groups[g][partition_subgroups[i][j]]; + } + } + } + return result; + }; + result.create_cross_partition_all_reduce = + [creator, expand_partition_groups]( + SpmdBuilder* b, HloInstruction* operand, HloComputation* reduction, + const std::vector>& partition_subgroups, + int64 channel_id) { + return creator.create_cross_partition_all_reduce( + b, operand, reduction, expand_partition_groups(partition_subgroups), + channel_id); + }; + result.create_cross_partition_collective_permute = + [creator, device_groups]( + SpmdBuilder* b, HloInstruction* operand, + std::vector>& src_dst_pairs, + int64 next_channel_id) { + std::vector> expanded_pairs( + src_dst_pairs.size() * device_groups.size()); + for (int64 g = 0; g < device_groups.size(); ++g) { + for (int64 i = 0; i < src_dst_pairs.size(); ++i) { + expanded_pairs[g * src_dst_pairs.size() + i] = + std::pair{ + device_groups[g][src_dst_pairs[i].first], + device_groups[g][src_dst_pairs[i].second]}; + } + } + return creator.create_cross_partition_collective_permute( + b, operand, expanded_pairs, next_channel_id); + }; + result.create_cross_partition_all_to_all = + [creator, expand_partition_groups]( + SpmdBuilder* b, absl::Span operands, + const std::vector>& partition_subgroups, + int64 channel_id, absl::optional split_dimension) { + return creator.create_cross_partition_all_to_all( + b, operands, expand_partition_groups(partition_subgroups), + channel_id, split_dimension); + }; + if (creator.create_cross_partition_all_gather) { + result.create_cross_partition_all_gather = + [creator, expand_partition_groups]( + SpmdBuilder* b, HloInstruction* operand, const Shape& ag_shape, + const std::vector>& partition_subgroups, + int64 channel_id, int64 all_gather_dimension) { + return creator.create_cross_partition_all_gather( + b, operand, ag_shape, + expand_partition_groups(partition_subgroups), channel_id, + all_gather_dimension); + }; + } + return result; +} + +} // namespace + +PartitionedHlo::PartitioningState CreatePerGroupPartitioningState( + const PartitionedHlo::PartitioningState& state, + const std::vector>& device_groups, SpmdBuilder* b) { + auto result = state; + result.collective_ops_creator = GetPerGroupCollectiveOpsCreator( + state.collective_ops_creator, device_groups); + result.partition_id = + GetInGroupPartitionId(state.partition_id, device_groups, b); + // Create a string key for the groups. 
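+  // e.g. device_groups {{0,1},{2,3}} produces the cache key "0,1;2,3".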
+ std::vector per_group_strings(device_groups.size()); + for (int64 i = 0; i < per_group_strings.size(); ++i) { + per_group_strings[i] = absl::StrJoin(device_groups[i], ","); + } + auto& grouped_cache = + state.reshard_cache->groupd_caches[absl::StrJoin(per_group_strings, ";")]; + if (!grouped_cache) { + grouped_cache = absl::make_unique(); + } + result.reshard_cache = grouped_cache.get(); + return result; +} + +HloInstruction* PerGroupSliceFromReplicated( + HloInstruction* replicated, HloInstruction* partition_id, + const std::vector>& device_groups, + absl::Span group_dims, absl::Span group_dim_sizes, + SpmdBuilder* b) { + std::vector group_ids(device_groups.size() * device_groups[0].size()); + for (int64 g = 0; g < device_groups.size(); ++g) { + for (int64 device : device_groups[g]) { + group_ids[device] = g; + } + } + auto group_id_table = b->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1(group_ids))); + auto group_id = b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeScalarShape(U32), + b->AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(U32, {1}), group_id_table, {partition_id}, + {1})))); + std::vector group_level_tile_dims(replicated->shape().rank(), 1); + for (int64 i = 0; i < group_dims.size(); ++i) { + group_level_tile_dims[group_dims[i]] = group_dim_sizes[i]; + } + Array group_level_tile(group_level_tile_dims); + group_level_tile.Each([&](absl::Span indices, int64* group) { + *group = 0; + for (int64 dim : group_dims) { + *group *= group_level_tile.dim(dim); + *group += indices[dim]; + } + }); + auto group_level_sharding = HloSharding::Tile(group_level_tile); + auto padded_hlo = PadBaseShapeBeforeUnevenTiledSharding( + replicated, group_level_sharding, b); + auto shard_shape = + MakePartitionedShape(replicated->shape(), group_level_sharding); + return b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo, + MakePartitionOffsets(replicated->shape(), group_level_sharding, group_id, + b), + shard_shape.dimensions())); +} + +absl::optional TransposeShardingWithCollapsedDims( + const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src) { + if (source.IsTileMaximal()) { + return source; + } + std::vector tgt_dims_skipping_new(tgt_to_src.size(), -1); + int64 skipped_tgt_dims = 0; + for (int64 i = 0; i < tgt_to_src.size(); ++i) { + if (tgt_to_src[i] < 0) { + skipped_tgt_dims++; + } else { + tgt_dims_skipping_new[i] = i - skipped_tgt_dims; + } + } + int64 skipped_src_dims = absl::c_count(src_to_tgt, -1); + std::vector perm(src_to_tgt.size()); + for (int64 i = 0; i < src_to_tgt.size(); ++i) { + if (src_to_tgt[i] < 0) { + if (source.tile_assignment().dim(i) > 1) { + return absl::nullopt; + } + perm[src_to_tgt.size() - skipped_src_dims] = i; + skipped_src_dims--; + } else { + perm[tgt_dims_skipping_new[src_to_tgt[i]]] = i; + } + } + auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + if (skipped_tgt_dims == 0) { + return tgt_sharding; + } + auto reshape_tiles = tgt_sharding.tile_assignment(); + std::vector tgt_tiles(tgt_to_src.size(), 1); + for (int64 i = 0; i < tgt_tiles.size(); ++i) { + if (tgt_to_src[i] >= 0) { + tgt_tiles[i] = reshape_tiles.dim(tgt_dims_skipping_new[i]); + } + } + reshape_tiles.Reshape(tgt_tiles); + return HloSharding::Tile(reshape_tiles); +} + +absl::optional ParseReductionComputation( + const HloComputation* reduction_comp) { + if (reduction_comp->num_parameters() != 2) { + return absl::nullopt; + } + auto root = 
reduction_comp->root_instruction(); + if (!root->IsElementwiseBinary()) { + return absl::nullopt; + } + if (!absl::c_linear_search(root->operands(), + reduction_comp->parameter_instruction(0)) || + !absl::c_linear_search(root->operands(), + reduction_comp->parameter_instruction(1))) { + return absl::nullopt; + } + return root->opcode(); +} + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index 5f245667970..6906b52ca79 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -33,9 +33,16 @@ namespace spmd { // Returns true if the given sharding contains any replicated sharding. bool HasReplicatedSharding(const HloSharding& sharding); +// Creates constant value instructions of the given shape. The literal must be a +// scalar shape and is broadcast to the given shape. +HloInstruction* CreateConstant(const Shape& shape, Literal value, + SpmdBuilder* b); // Creates zero value instructions of the given shape. HloInstruction* CreateZero(const Shape& shape, SpmdBuilder* b); +// Creates one value instructions of the given shape. +HloInstruction* CreateOne(const Shape& shape, SpmdBuilder* b); + template HloInstruction* CreateR0WithType(PrimitiveType type, NativeT value, SpmdBuilder* b) { @@ -87,10 +94,12 @@ Shape MakeNonPaddedShapeForGivenPartition(const Shape& shape, // Generates the HLO instructions that represent the dimension offsets on any // device. The size of the returned vector is the rank of the given shape. -std::vector MakePartitionOffsets(const Shape& shape, - const HloSharding& sharding, - HloInstruction* partition_id, - SpmdBuilder* b); +// If `dims` is non-empty, the generated offsets will only be non-zero for those +// dimensions. +std::vector MakePartitionOffsets( + const Shape& shape, const HloSharding& sharding, + HloInstruction* partition_id, SpmdBuilder* b, + absl::Span dims = {}); // Returns the offsets of the partition in the tile assignment. std::vector MakeTiledPartitionOrdinals( @@ -262,6 +271,106 @@ absl::optional GetKValueInTopKWhenPartitionSortDim(HloInstruction* hlo); HloInstruction* SliceFirstK(HloInstruction* hlo, SpmdBuilder* builder, int64 slice_dim, int64 k); +// Check if a dimension is sharded. +int64 ShardCountAtDim(const HloSharding& sharding, int64 dim); + +// Returns the list of source-target pairs of dimensions to swap during +// resharding via all-to-all. Reshard can be done by swapping each pair at a +// time. +absl::optional>> +GetReshardAllToAllSourceTargetDims(const HloSharding& source, + const HloSharding& target); + +// Returns whether the resharding can be done via collective-permute. +bool CanReshardWithCollectivePermute(const HloSharding& source, + const HloSharding& target); + +// Represents grouping devices in a tiled sharding along certain dimensions. +// Elements in group dimensions define different device groups, and the sharding +// represents the in-group sharding. 
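+// A minimal usage sketch (the 2x2 iota device assignment below is an assumed
+// example for illustration): grouping a [2,2] tile assignment along dimension
+// 0 produces device groups {0,1} and {2,3}, each with an in-group [1,2]
+// tiling, and UngroupSharding recovers the original sharding.
+//
+//   Array<int64> tiles({2, 2});
+//   tiles.FillIota(0);  // [[0, 1], [2, 3]]
+//   GroupedSharding grouped =
+//       GroupShardingOnDims(HloSharding::Tile(tiles), /*group_dims=*/{0});
+//   // grouped.device_groups == {{0, 1}, {2, 3}}
+//   // grouped.sharding == {devices=[1,2]0,1}
+//   HloSharding restored = UngroupSharding(grouped);  // == HloSharding::Tile(tiles)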
+struct GroupedSharding { + GroupedSharding(std::vector> device_groups, + std::vector group_dims, + std::vector group_dim_sizes, int64 rank, + HloSharding grouped_sharding) + : device_groups(std::move(device_groups)), + group_dims(std::move(group_dims)), + group_dim_sizes(std::move(group_dim_sizes)), + sharding(std::move(grouped_sharding)) {} + std::vector> device_groups; + std::vector group_dims; + std::vector group_dim_sizes; + int64 rank; + HloSharding sharding; +}; + +// Creates a GroupedSharding for a tiled sharding. +GroupedSharding GroupShardingOnDims(const HloSharding& sharding, + absl::Span group_dims); + +// Reconstructs the ungrouped sharding from a GroupedSharding. +HloSharding UngroupSharding(const GroupedSharding& grouped_sharding); + +// Returns a new GroupedSharding that has the same group definition as +// `reference`. +GroupedSharding AlignGroupsWith(GroupedSharding grouped_sharding, + const GroupedSharding& reference, + bool ignore_group_order = false); + +// Returns the per-group base shape, i.e., before applying the in-group +// sharding. +Shape GetPerGroupBaseShape(const GroupedSharding& grouped_sharding, + const Shape& original_base_shape); + +// Creates the nested partitioner state for in-group partitioning. +PartitionedHlo::PartitioningState CreatePerGroupPartitioningState( + const PartitionedHlo::PartitioningState& state, + const std::vector>& device_groups, SpmdBuilder* b); + +// Partially shards a replicated HLO into groups along the group dimensions, and +// within each group data is still replicated. +HloInstruction* PerGroupSliceFromReplicated( + HloInstruction* replicated, HloInstruction* partition_id, + const std::vector>& device_groups, + absl::Span group_dims, absl::Span group_dim_sizes, + SpmdBuilder* b); + +// Similar to hlo_sharding_util::TransposeSharding(), but allows removing/adding +// non-partitioned dimensions. In src_to_tgt and tgt_to_src, -1 represents a +// non-existing dimension. +absl::optional TransposeShardingWithCollapsedDims( + const HloSharding& source, absl::Span src_to_tgt, + absl::Span tgt_to_src); + +// Returns the opcode if `reduction_comp` represents a simple binary elementwise +// computation on the two operands. +absl::optional ParseReductionComputation( + const HloComputation* reduction_comp); + +// Pad the shape from partial replicate shape for `dst_sharding`. +// If dst_sharding needs more padding and per_shard_size increased in +// dst_sharding, halo exchange on the right side is needed. +absl::optional PadFromPartialReplicateShape( + HloInstruction* hlo, const Shape& base_shape, + const HloSharding& src_sharding, const HloSharding& dst_sharding, + const std::vector& expand_tile_dims, + const SPMDCollectiveOpsCreator& collective_ops_creator, + int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b); + +// Get the compatible sharding from a partial replicate sharding to the given +// target tile dimensions. +// Compatible means the replicated sharding can be transformed to the target +// tile dimensions by dynamic slice. +// For example, if partial_sharding is +// {devices=[1,2,2]0,1,2,3 last_tile_dim_replicate} +// and the target tile dims are {2, 2}, the returned compatible sharding will be +// sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}. +// If partial_sharding is not partially replicated or can't be resharded to +// target_tile_dims by dynamic slice, return absl::nullopt.
+absl::optional PartialReplicateToTileCompatibleSharding( + const HloSharding& partial_sharding, + const std::vector& target_tile_dims); + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/topk_rewriter.cc b/tensorflow/compiler/xla/service/topk_rewriter.cc new file mode 100644 index 00000000000..000b1e94ece --- /dev/null +++ b/tensorflow/compiler/xla/service/topk_rewriter.cc @@ -0,0 +1,194 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/topk_rewriter.h" + +#include "absl/algorithm/container.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +static bool IsNanSafeGt(HloComputation* comp) { + namespace m = match; + auto match_bitcast_f32 = [](int64 parameter_number) { + auto param = m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(F32)); + auto param_s32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32)); + auto param_u32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32)); + return m::Select( + m::Lt(param_s32, m::ConstantScalar(0)), + m::BitcastConvert( + m::Subtract(m::ConstantScalar(std::numeric_limits::max()), + param_u32)) + .WithShape(m::Shape().WithElementType(S32)), + param_s32); + }; + auto match_bitcast_bf16 = [](int64 parameter_number) { + auto param = m::Convert(m::Parameter(parameter_number) + .WithShape(m::Shape().WithElementType(BF16))) + .WithShape(m::Shape().WithElementType(F32)); + auto param_s32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(S32)); + auto param_u32 = + m::BitcastConvert(param).WithShape(m::Shape().WithElementType(U32)); + return m::Select( + m::Lt(param_s32, m::ConstantScalar(0)), + m::BitcastConvert( + m::Subtract(m::ConstantScalar(std::numeric_limits::max()), + param_u32)) + .WithShape(m::Shape().WithElementType(S32)), + param_s32); + }; + return Match(comp->root_instruction(), + m::Gt(match_bitcast_f32(0), match_bitcast_f32(1))) || + Match(comp->root_instruction(), + m::Gt(match_bitcast_bf16(0), match_bitcast_bf16(1))); +} + +StatusOr TopkRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->computations()) { + for (HloInstruction* inst : comp->MakeInstructionPostOrder()) { + HloSortInstruction* sort = DynCast(inst); + if (sort == nullptr || sort->operand_count() != 2) { + continue; + } + HloInstruction* data = sort->mutable_operand(0); + HloIotaInstruction* iota = + DynCast(sort->mutable_operand(1)); + const PrimitiveType element_type = data->shape().element_type(); + if ((data->shape().rank() != 1 && data->shape().rank() != 2) || + (element_type != F32 && element_type != BF16)) { + 
continue; + } + if (iota == nullptr || iota->shape().rank() != data->shape().rank() || + iota->shape().element_type() != S32 || + iota->opcode() != HloOpcode::kIota || + iota->iota_dimension() != sort->sort_dimension()) { + continue; + } + if (!IsNanSafeGt(sort->to_apply())) { + continue; + } + const int64 sort_dim = sort->sort_dimension(); + const int64 batch_dim = sort_dim == 1 ? 0 : 1; + const bool has_batch = data->shape().rank() == 2; + + bool supported = true; + absl::optional k; + for (HloInstruction* gte : sort->users()) { + if (gte->opcode() != HloOpcode::kGetTupleElement || + gte->user_count() != 1) { + supported = false; + break; + } + const HloInstruction* slice = gte->users()[0]; + if (slice->opcode() != HloOpcode::kSlice) { + // Non-slice user means we are not doing a TopK + supported = false; + break; + } + if (absl::c_any_of(slice->slice_starts(), + [](int x) { return x != 0; }) || + absl::c_any_of(slice->slice_strides(), + [](int x) { return x != 1; })) { + // Strided slice or slicing at the beginning isn't supported. + supported = false; + break; + } + if (has_batch && slice->slice_limits(batch_dim) != + slice->operand(0)->shape().dimensions(batch_dim)) { + // Slicing along the batch dimension isn't supported. + supported = false; + break; + } + if (k == absl::nullopt) { + k = slice->slice_limits(sort_dim); + } else if (k != slice->slice_limits(sort_dim)) { + // Different k for the different operands isn't supported. + supported = false; + break; + } + } + if (k == absl::nullopt || !supported) { + continue; + } + + // Profitability check. + if (!is_profitable_to_convert_(sort, *k)) { + continue; + } + + const int64 batch_size = + has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; + const int64 input_size = sort->operand(0)->shape().dimensions(sort_dim); + HloInstruction* input = sort->mutable_operand(0); + if (has_batch && sort_dim == 0) { + input = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, + {1, 0})); + } + + Shape topk_shape = + has_batch ? 
ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(element_type, + {batch_size, k.value()}), + ShapeUtil::MakeShape(S32, {batch_size, k.value()})}) + : ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(element_type, {k.value()}), + ShapeUtil::MakeShape(S32, {k.value()})}); + HloInstruction* topk = comp->AddInstruction( + HloInstruction::CreateCustomCall(topk_shape, {input}, "TopK")); + HloInstruction* value_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(0), topk, 0)); + HloInstruction* index_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + + if (has_batch && sort_dim == 0) { + value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), + value_gte, {1, 0})); + index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, + {1, 0})); + } + + for (HloInstruction* gte : sort->users()) { + for (HloInstruction* slice : gte->users()) { + if (gte->tuple_index() == 0) { + TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(value_gte)); + } else if (gte->tuple_index() == 1) { + TF_RETURN_IF_ERROR(slice->ReplaceAllUsesWith(index_gte)); + } else { + LOG(FATAL) << "Sort with more than 2 output isn't supported in " + "topk rewriter"; + } + } + } + changed = true; + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/topk_rewriter.h b/tensorflow/compiler/xla/service/topk_rewriter.h new file mode 100644 index 00000000000..68f8a8145e2 --- /dev/null +++ b/tensorflow/compiler/xla/service/topk_rewriter.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +// This pass pattern-matches soups of HLOs executing a TopK operation and +// replaces them with a TopK CustomCall when the given values are supported by +// the CustomCall and it is more efficient to use that implementation. +class TopkRewriter : public HloModulePass { + public: + explicit TopkRewriter(std::function + is_profitable_to_convert) + : is_profitable_to_convert_(std::move(is_profitable_to_convert)) {} + + absl::string_view name() const override { return "topk-rewriter"; } + + StatusOr Run(HloModule* module) override; + + private: + // Predicate that returns true if a sort instruction is profitable to be + // converted into a custom call. 
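+  // For example, a sketch of constructing the pass (the k threshold below is
+  // an assumed choice; the unit tests use a predicate that always returns
+  // true):
+  //
+  //   TopkRewriter rewriter(
+  //       [](const HloSortInstruction* /*sort*/, int64 k) { return k <= 1024; });
+  //   TF_ASSIGN_OR_RETURN(bool changed, rewriter.Run(module));
+  //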
+ std::function + is_profitable_to_convert_; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TOPK_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/topk_rewriter_test.cc b/tensorflow/compiler/xla/service/topk_rewriter_test.cc new file mode 100644 index 00000000000..ec5b34b1c0a --- /dev/null +++ b/tensorflow/compiler/xla/service/topk_rewriter_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/topk_rewriter.h" + +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +using TopkRewriterTest = HloTestBase; + +std::string getComparator() { + return R"( +%compare { + %p.1.lhs.8 = s32[] parameter(2) + %p.1.rhs.9 = s32[] parameter(3) + %p.0.lhs.6 = f32[] parameter(0) + %bitcast-convert.11 = s32[] bitcast-convert(%p.0.lhs.6) + %constant.15 = s32[] constant(0) + %compare.16 = pred[] compare(%bitcast-convert.11, %constant.15), direction=LT + %constant.10 = u32[] constant(2147483647) + %bitcast-convert.12 = u32[] bitcast-convert(%p.0.lhs.6) + %subtract.13 = u32[] subtract(%constant.10, %bitcast-convert.12) + %bitcast-convert.14 = s32[] bitcast-convert(%subtract.13) + %select.17 = s32[] select(%compare.16, %bitcast-convert.14, + %bitcast-convert.11) + %p.0.rhs.7 = f32[] parameter(1) + %bitcast-convert.19 = s32[] bitcast-convert(%p.0.rhs.7) + %constant.23 = s32[] constant(0) + %compare.24 = pred[] compare(%bitcast-convert.19, %constant.23), direction=LT + %constant.18 = u32[] constant(2147483647) + %bitcast-convert.20 = u32[] bitcast-convert(%p.0.rhs.7) + %subtract.21 = u32[] subtract(%constant.18, %bitcast-convert.20) + %bitcast-convert.22 = s32[] bitcast-convert(%subtract.21) + %select.25 = s32[] select(%compare.24, %bitcast-convert.22, + %bitcast-convert.19) + ROOT %compare.26 = pred[] compare(%select.17, %select.25), direction=GT +})"; +} + +TEST_F(TopkRewriterTest, Rewrite) { + const std::string hlo_string = R"( +HloModule module +)" + getComparator() + R"( +ENTRY cluster { + %arg_tuple.1 = f32[8,1234567] parameter(0) + %iota.4 = s32[8,1234567] iota(), iota_dimension=1 + %sort.27 = (f32[8,1234567], s32[8,1234567]) sort(%arg_tuple.1, %iota.4), + dimensions={1}, is_stable=true, to_apply=%compare + %get-tuple-element.28 = f32[8,1234567] get-tuple-element(%sort.27), index=0 + %slice.29 = f32[8,5] slice(%get-tuple-element.28), slice={[0:8], [0:5]} + %get-tuple-element.30 = s32[8,1234567] get-tuple-element(%sort.27), index=1 + 
%slice.31 = s32[8,5] slice(%get-tuple-element.30), slice={[0:8], [0:5]} + ROOT %tuple.32 = (f32[8,5], s32[8,5]) tuple(%slice.29, %slice.31) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + TF_ASSERT_OK(HloDCE().Run(module.get()).status()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(op::CustomCall(op::Parameter(0)), 0), + op::GetTupleElement(op::CustomCall(op::Parameter(0)), 1))); + const HloInstruction* cc = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + EXPECT_THAT(cc->custom_call_target(), "TopK"); +} + +TEST_F(TopkRewriterTest, RewriteUnbatched) { + const std::string hlo_string = R"( +HloModule module +)" + getComparator() + R"( +ENTRY cluster { + %arg_tuple.1 = f32[1234567] parameter(0) + %iota.4 = s32[1234567] iota(), iota_dimension=0 + %sort.27 = (f32[1234567], s32[1234567]) sort(%arg_tuple.1, %iota.4), + dimensions={0}, is_stable=true, to_apply=%compare + %get-tuple-element.28 = f32[1234567] get-tuple-element(%sort.27), index=0 + %slice.29 = f32[5] slice(%get-tuple-element.28), slice={[0:5]} + %get-tuple-element.30 = s32[1234567] get-tuple-element(%sort.27), index=1 + %slice.31 = s32[5] slice(%get-tuple-element.30), slice={[0:5]} + ROOT %tuple.32 = (f32[5], s32[5]) tuple(%slice.29, %slice.31) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + TF_ASSERT_OK(HloDCE().Run(module.get()).status()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::GetTupleElement(op::CustomCall(op::Parameter(0)), 0), + op::GetTupleElement(op::CustomCall(op::Parameter(0)), 1))); + const HloInstruction* cc = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + EXPECT_THAT(cc->custom_call_target(), "TopK"); +} + +TEST_F(TopkRewriterTest, RewriteTranspose) { + const std::string hlo_string = R"( +HloModule module +)" + getComparator() + R"( +ENTRY cluster { + %arg_tuple.1 = f32[1234567,8] parameter(0) + %iota.4 = s32[1234567,8] iota(), iota_dimension=0 + %sort.27 = (f32[1234567,8], s32[1234567,8]) sort(%arg_tuple.1, %iota.4), + dimensions={0}, is_stable=true, to_apply=%compare + %get-tuple-element.28 = f32[1234567,8] get-tuple-element(%sort.27), index=0 + %slice.29 = f32[5,8] slice(%get-tuple-element.28), slice={[0:5], [0:8]} + %get-tuple-element.30 = s32[1234567,8] get-tuple-element(%sort.27), index=1 + %slice.31 = s32[5,8] slice(%get-tuple-element.30), slice={[0:5], [0:8]} + ROOT %tuple.32 = (f32[5,8], s32[5,8]) tuple(%slice.29, %slice.31) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TopkRewriter rewriter([](const HloSortInstruction*, int64) { return true; }); + TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); + TF_ASSERT_OK(HloDCE().Run(module.get()).status()); + EXPECT_TRUE(changed); + LOG(INFO) << module->entry_computation()->ToString(); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Transpose(op::GetTupleElement( + op::CustomCall(op::Transpose(op::Parameter(0))), 0)), + op::Transpose(op::GetTupleElement( + 
op::CustomCall(op::Transpose(op::Parameter(0))), 1)))); + const HloInstruction* cc = module->entry_computation() + ->root_instruction() + ->operand(0) + ->operand(0) + ->operand(0); + EXPECT_THAT(cc->custom_call_target(), "TopK"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.cc b/tensorflow/compiler/xla/service/triangular_solve_expander.cc index cc483c310e8..d54eb9e78c3 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.cc +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.cc @@ -454,6 +454,9 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, } // namespace +TriangularSolveExpander::TriangularSolveExpander(int64 block_size) + : block_size_(block_size) {} + bool TriangularSolveExpander::InstructionMatchesPattern( HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kTriangularSolve; @@ -496,7 +499,7 @@ StatusOr TriangularSolveExpander::ExpandInstruction( BuildTriangularSolve(a, b, options.left_side(), options.lower(), transpose_a, conjugate_a, options.unit_diagonal(), - /*block_size=*/128, + /*block_size=*/block_size_, /*precision=*/PrecisionConfig::HIGHEST); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander.h b/tensorflow/compiler/xla/service/triangular_solve_expander.h index be2374ef8c8..362e8557229 100644 --- a/tensorflow/compiler/xla/service/triangular_solve_expander.h +++ b/tensorflow/compiler/xla/service/triangular_solve_expander.h @@ -23,6 +23,8 @@ namespace xla { class TriangularSolveExpander : public OpExpanderPass { public: + explicit TriangularSolveExpander(int64 block_size = 128); + absl::string_view name() const override { return "triangular_solve_expander"; } @@ -34,6 +36,8 @@ class TriangularSolveExpander : public OpExpanderPass { HloInstruction* instruction) override; private: + // Block size for BuildTriangularSolve + const int64 block_size_; // Mapping from op signatures to existing computations. absl::flat_hash_map computation_cache_; }; diff --git a/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc b/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc new file mode 100644 index 00000000000..84663f80e7a --- /dev/null +++ b/tensorflow/compiler/xla/service/triangular_solve_expander_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/triangular_solve_expander.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class TriangularExpanderTest : public HloTestBase, + public ::testing::WithParamInterface {}; + +TEST_P(TriangularExpanderTest, TestBlockSize) { + auto block_size = GetParam(); + std::string hlo_string = R"( + HloModule TensorFlowTriangularSolve + + ENTRY main { + a = f32[256,256]{1,0} parameter(0) + b = f32[256,192]{1,0} parameter(1) + ROOT triangular-solve = f32[256,192]{1,0} triangular-solve(a, b), + left_side=true, unit_diagonal=true, + lower=true, transpose_a=NO_TRANSPOSE + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + { + TriangularSolveExpander triangular_solve_expander(block_size); + + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&triangular_solve_expander, module.get())); + EXPECT_TRUE(result); + } + + // To test triangular solver expander we generate simple bidiagonal matrix: + // Solve a * x = b. + // Check that shape is still valid. + // Use reference matrix multiplication to test validity of result. + + Array2D a(256, 256); + for (int64 row = 0; row < a.dim(0); ++row) { + a(row, row) = 1; + if (row > 0) { + a(row, row - 1) = 0.01; + } + } + + Array2D b(256, 192); + const float kMax = static_cast(b.dim(0) * b.dim(1) + 1); + for (int64 row = 0; row < b.dim(0); ++row) { + for (int64 col = 0; col < b.dim(1); ++col) { + b(row, col) = static_cast(row + col + 1) / kMax; + } + } + auto la = LiteralUtil::CreateR2FromArray2D(a); + auto lb = LiteralUtil::CreateR2FromArray2D(b); + + TF_ASSERT_OK_AND_ASSIGN(Literal lx, Execute(std::move(module), {&la, &lb})); + + auto x_shape = lx.shape(); + EXPECT_EQ(x_shape.dimensions_size(), 2); + EXPECT_EQ(x_shape.dimensions(0), b.dim(0)); + EXPECT_EQ(x_shape.dimensions(1), b.dim(1)); + + Array2D x(x_shape.dimensions(0), x_shape.dimensions(1)); + x.SetValues(lx.data()); + + auto ref_b = ReferenceUtil::MatmulArray2D(a, x); + auto ref_lb = LiteralUtil::CreateR2FromArray2D(*ref_b); + + EXPECT_TRUE( + LiteralTestUtil::NearOrEqual(ref_lb, lb, ErrorSpec{0.001, 0.001})); +} + +// block_size test limits based on the following considerations: +// - test at least twice the range of original value +// - try to test odd values unaligned with matrix dims +// - full 1-256 range test takes too long to run + +INSTANTIATE_TEST_CASE_P(TriangularExpanderTestInstances, TriangularExpanderTest, + ::testing::Range(2, 256, 7)); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h index dfaac677724..6a19a1fac09 100644 --- a/tensorflow/compiler/xla/shape.h +++ b/tensorflow/compiler/xla/shape.h @@ -49,7 +49,7 @@ class Shape { // Returns the rank (number of dimensions) of the given shape. Shape must be // an array. 
int64 rank() const { - CHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); + DCHECK(IsArray()) << "Non-arrays do not have a rank, shape: " << ToString(); return dimensions_.size(); } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index bce40578132..0833919b124 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -339,6 +339,15 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( TF_DCHECK_OK(ValidateShape(*shape)); } +/* static */ void ShapeUtil::CopyDynamicDimensions(Shape* to, + const Shape& from) { + CHECK_EQ(to->rank(), from.rank()); + for (int64 i = 0; i < from.rank(); ++i) { + to->set_dynamic_dimension(i, from.is_dynamic_dimension(i)); + } + TF_DCHECK_OK(ValidateShape(*to)); +} + /* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) { return primitive_util::IsIntegralType(shape.element_type()); } @@ -522,13 +531,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( text += ")"; return text; } - string result = StrCat( - primitive_util::LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - StrAppend(&result, (i > 0) ? "," : "", - shape.is_dynamic_dimension(i) ? "<=" : "", shape.dimensions(i)); - } - result += "]"; + string result = HumanString(shape); if (IsScalar(shape)) { string layout_str = LayoutUtil::HumanString(shape.layout()); // Don't print "{}" as layout for scalars. @@ -780,9 +783,18 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( /* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = original; - new_shape.set_element_type(type); - return new_shape; + if (original.IsTuple()) { + std::vector new_operands; + new_operands.reserve(original.tuple_shapes_size()); + for (const Shape& operand : original.tuple_shapes()) { + new_operands.push_back(ChangeElementType(operand, type)); + } + return MakeTupleShape(new_operands); + } else { + Shape new_shape = original; + new_shape.set_element_type(type); + return new_shape; + } } /* static */ bool ShapeUtil::IndexIsValid(const Shape& shape, diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index fe1a8acf6e4..3f69a8b0aca 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -269,6 +269,14 @@ class ShapeUtil { if (SameElementType(a, b)) { return a.element_type(); } + // If only one of A and B are floating use the floating point type. + if (ElementIsFloating(a) && !ElementIsFloating(b)) { + return a.element_type(); + } + if (ElementIsFloating(b) && !ElementIsFloating(a)) { + return b.element_type(); + } + // Use the higher precision type. return primitive_util::BitWidth(a.element_type()) < primitive_util::BitWidth(b.element_type()) ? b.element_type() @@ -377,6 +385,9 @@ class ShapeUtil { // Appends a major dimension to the shape with the given bound. static void AppendMajorDimension(int bound, Shape* shape); + // Copy the dynamic dimensions property from one shape to another. + static void CopyDynamicDimensions(Shape* to, const Shape& from); + // Returns an empty tuple shape. Can be used as a sentinel Shape value. 
static Shape MakeNil() { return MakeTupleShape({}); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 83851fabd53..3dac381ae7d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -728,6 +728,7 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, + tags = ["no_oss"], # b/163416869 deps = [ ":test_macros_header", "//tensorflow/compiler/xla:array2d", @@ -1115,10 +1116,53 @@ xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], - shard_count = 40, + shard_count = 50, tags = [ "no_rocm", + "optonly", + # Timed out on 2020-07-18 "nozapfhahn", + ], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "convolution_test_1d", + timeout = "long", + srcs = ["convolution_test_1d.cc"], + # Turn on logging so that VLOG statements don't appear uncovered to zapfhahn. + args = ["--vmodule=convolution_emitter=7"], + # In the open source build, convolution_test_1d_gpu fails because it doesn't + # recognize --vmodule. + disabled_backends = [ + "cpu", + "gpu", + ], + shard_count = 50, + tags = [ + "no_rocm", + "optonly", + ], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "convolution_test_1d_no_vmodule", + timeout = "long", + srcs = ["convolution_test_1d.cc"], + backends = [ + "cpu", + "gpu", + ], + shard_count = 50, + tags = [ + "no_rocm", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1147,6 +1191,23 @@ xla_test( ], ) +xla_test( + name = "convolution_test_1d_autotune_disabled", + timeout = "long", + srcs = ["convolution_test_1d.cc"], + args = ["--xla_gpu_autotune_level=0"], + backends = ["gpu"], + shard_count = 40, + tags = [ + "no_rocm", + "optonly", + ], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + xla_test( name = "convolution_test_gpu_alternative_layout", timeout = "long", @@ -1163,6 +1224,22 @@ xla_test( ], ) +xla_test( + name = "convolution_test_1d_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test_1d.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + tags = [ + "no_rocm", + ], + deps = CONVOLUTION_TEST_DEPS + [ + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], +) + xla_test( name = "convolution_variants_test", timeout = "long", @@ -2012,6 +2089,31 @@ xla_test( ], ) +xla_test( + name = "dynamism_inference_test", + srcs = ["dynamism_inference_test.cc"], + deps = [ + ":test_macros_header", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:prng", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/strings", 
+ ], +) + xla_test( name = "compute_constant_test", srcs = ["compute_constant_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index a956b85a940..fdc679a61c6 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1203,6 +1203,16 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32sTO) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 2.25f, NAN, NAN}); + EqTotalOrder(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, true, false}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) { XlaBuilder builder(TestName()); auto lhs = ConstantR1(&builder, {}); @@ -1222,6 +1232,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32s) { ComputeAndCompareR1(&builder, {false, true, true, false, false}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareGeF32sTO) { + SetFastMathDisabled(true); + XlaBuilder builder(TestName()); + auto lhs = + ConstantR1(&builder, {-2.5f, 25.5f, 2.25f, NAN, 6.0f, 6.0f}); + auto rhs = ConstantR1(&builder, {10.0f, 5.0f, 1.0f, 10.0f, NAN, -NAN}); + GeTotalOrder(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, true, true, true, false, true}, + {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareGtF32s) { SetFastMathDisabled(true); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/buffer_donation_test.cc b/tensorflow/compiler/xla/tests/buffer_donation_test.cc index 5f936870103..f78083fe2af 100644 --- a/tensorflow/compiler/xla/tests/buffer_donation_test.cc +++ b/tensorflow/compiler/xla/tests/buffer_donation_test.cc @@ -61,7 +61,7 @@ class BufferDonationTest : public HloTestBase { absl::Span argument_literals, absl::Span donate_arguments, absl::Span expected_runtime_aliasing, - const Literal& expected) { + const Literal& expected, std::string expected_failure = "") { // Create a copy of the output shape because the HLO module is std::moved // into the compiler and may be deallocated. 
const Shape output_shape = hlo_module->result_shape(); @@ -119,13 +119,23 @@ class BufferDonationTest : public HloTestBase { } }); - args.emplace_back(ExecutionInput(std::move(owned_buffers))); + args.emplace_back( + ExecutionInput(std::move(owned_buffers), argument_literal.shape())); } - TF_ASSERT_OK_AND_ASSIGN( - ExecutionOutput output, + StatusOr output_status = executable->ExecuteAsyncOnStream(&service_run_options, std::move(args), - /*hlo_execution_profile=*/nullptr)); + /*hlo_execution_profile=*/nullptr); + if (!expected_failure.empty()) { + ASSERT_FALSE(output_status.ok()); + ASSERT_TRUE(absl::StrContains(output_status.status().error_message(), + expected_failure)) + << "got: \n" + << output_status.status().error_message() << " \nvs want\n" + << expected_failure; + return; + } + ExecutionOutput output = output_status.ConsumeValueOrDie(); se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer(); LOG(INFO) << "result allocation = " << result_root_buffer.opaque() @@ -302,5 +312,37 @@ ENTRY entry { #endif } +TEST_F(BufferDonationTest, TestMustAliasNotDonated) { + HloModuleConfig config; + + StatusOr> module = + ParseAndReturnVerifiedModule(R"( +HloModule module + +ENTRY entry { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT out = (f32[], f32[]) tuple(a, b) +} + )", + config); + + TF_ASSERT_OK(module->get()->input_output_alias_config().SetUpAlias( + {0}, 0, {}, HloInputOutputAliasConfig::kMustAlias)); + + std::vector args; + args.push_back(LiteralUtil::CreateR0(0.1)); + args.push_back(LiteralUtil::CreateR0(0.2)); + Literal expected = LiteralUtil::MakeTupleFromSlices( + {LiteralUtil::CreateR0(0.1), LiteralUtil::CreateR0(0.2)}); + +#ifndef XLA_TEST_BACKEND_INTERPRETER + RunAndCheck(std::move(*module), args, + /*donate_arguments=*/{false, false}, {true, false}, expected, + "An input was configured to be must-alias at " + "compile time but not donated at runtime:"); +#endif +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc index 7459b3d3f1f..ed5fabb663e 100644 --- a/tensorflow/compiler/xla/tests/collective_ops_test.cc +++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc @@ -568,7 +568,8 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { ten = u32[] constant(10) sum = u32[] add(replica, ten) p = u32[2] broadcast(sum), dimensions={} - ROOT permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}} + permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}, {2,2}} + ROOT copy = u32[2] copy(permute) } )"; const int64 kNumReplicas = 4; diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index c63f1d0edf3..8021d6fe5db 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Tests of convolution with trivial kernels and no special variations (like +// Tests of 2+D convolution with trivial kernels and no special variations (like // strides and padding). 
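The TestMustAliasNotDonated case above establishes the parameter-to-output aliasing programmatically via SetUpAlias({0}, 0, {}, kMustAlias) and then expects ExecuteAsyncOnStream to fail because the caller never donates the buffer. The same constraint can also be expressed directly in HLO text; the attribute syntax below is recalled from the HLO parser and should be treated as an assumption, not as part of this patch:

// Hypothetical HLO-text form of the must-alias configuration used in the
// test above: output index {0} must alias parameter 0 at index {}.
constexpr const char* kMustAliasModule = R"(
HloModule module, input_output_alias={ {0}: (0, {}, must-alias) }

ENTRY entry {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT out = (f32[], f32[]) tuple(a, b)
}
)";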
#include @@ -240,174 +240,6 @@ class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { - XlaBuilder builder(TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - Conv(input, filter, {1}, Padding::kValid); - } - - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{510, 610, 710, 810}}}); - - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} - -template -class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { - public: - void RunTest() { - XlaBuilder builder(TestName()); - { - Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); - } - - Array3D input( - {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); - Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); - - Array3D expected({{{570.0f, 670.0f, 770.0f}}}); - - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); - } -}; // namespace - -TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); -TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } - -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { - XlaBuilder builder(TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. 
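As a sanity check on the 1D cases being moved out of this file (they reappear verbatim in convolution_test_1d.cc below), the RHS-dilated expectation {570, 670, 770} can be reproduced with a few lines of plain C++, independent of XLA: dilation 2 spreads the two filter taps over input offsets 0 and 2, so the first output is 1*10 + 3*20 + 6*30 + 8*40 = 570.

#include <cstdio>

// Reference computation for Convolve1D_1x2x5_1x2x2_WithRHSDilation (sketch).
int main() {
  const float input[2][5] = {{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}};
  const float filter[2][2] = {{10, 20}, {30, 40}};  // [input_feature][tap]
  const int rhs_dilation = 2;
  // Effective filter width = (2 - 1) * rhs_dilation + 1 = 3, so a width-5
  // input yields 3 valid output positions.
  for (int o = 0; o < 3; ++o) {
    float acc = 0;
    for (int f = 0; f < 2; ++f) {      // input feature
      for (int k = 0; k < 2; ++k) {    // filter tap
        acc += input[f][o + k * rhs_dilation] * filter[f][k];
      }
    }
    std::printf("%g ", acc);           // prints: 570 670 770
  }
  return 0;
}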
- ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, - /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); - } - - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); - - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} - -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { - XlaBuilder builder(TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, - /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); - } - - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); - - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} - -template -class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { - public: - void RunTest() { - XlaBuilder builder(TestName()); - { - Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. 
- ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); - } - - Array3D input( - {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); - Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); - - Array3D expected( - {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); - - auto input_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); - } -}; - -TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); -TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } - XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { XlaBuilder builder(TestName()); std::vector input_dims = {1, 4, 2, 3, 3}; @@ -1714,150 +1546,7 @@ INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, ConvolveWithAndWithoutCanonicalization, ::testing::Values(true, false)); -struct Convolve1DTestParam { - int64 input_feature; - int64 output_feature; - int64 batch; - int64 window_size; - int64 num_windows; -}; -class Convolve1D1WindowTestBase - : public ConvolutionTest, - public ::testing::WithParamInterface { - protected: - template - void TestImpl() { - XlaBuilder builder(TestName()); - int64 input_feature = GetParam().input_feature; - int64 output_feature = GetParam().output_feature; - int64 batch = GetParam().batch; - int64 num_windows = GetParam().num_windows; - int64 window_size = GetParam().window_size; - std::vector input_dims = {batch, window_size + num_windows - 1, - input_feature}; - std::vector filter_dims = {window_size, input_feature, - output_feature}; - Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); - Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); - { - auto input = Parameter(&builder, 0, input_shape, "input"); - auto filter = Parameter(&builder, 1, filter_shape, "filter"); - - // Tensorflow dimension numbers for 1D convolution. 
- ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(0); - dnums.set_output_batch_dimension(0); - dnums.add_input_spatial_dimensions(1); - dnums.add_output_spatial_dimensions(1); - dnums.set_input_feature_dimension(2); - dnums.set_output_feature_dimension(2); - dnums.add_kernel_spatial_dimensions(0); - dnums.set_kernel_input_feature_dimension(1); - dnums.set_kernel_output_feature_dimension(2); - - ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums); - } - - std::vector input_elems(ShapeUtil::ElementsIn(input_shape), - static_cast(1.0f)); - auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); - - std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), - static_cast(1.0f)); - - auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); - - std::vector expect_elems(batch * output_feature * num_windows, - static_cast(window_size * input_feature)); - auto expected_r1 = LiteralUtil::CreateR1(expect_elems); - auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); - - auto input_literal = - client_->TransferToServer(input_r3).ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, expected_r3, - {input_literal.get(), filter_literal.get()}, - error_spec_); - } -}; - -class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {}; - -XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl(); } - -INSTANTIATE_TEST_CASE_P( - Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat, - ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, - Convolve1DTestParam{160, 1, 1, 5, 1}, - Convolve1DTestParam{24, 1, 1, 20, 1}, - Convolve1DTestParam{30, 1, 1, 20, 1}, - Convolve1DTestParam{23, 1, 1, 20, 20}, - Convolve1DTestParam{25, 1, 1, 20, 1}, - Convolve1DTestParam{24, 1, 1, 10, 5}, - Convolve1DTestParam{160, 1, 1, 10, 1}, - Convolve1DTestParam{255, 1, 1, 3, 1}, - Convolve1DTestParam{130, 1, 1, 1, 2}, - Convolve1DTestParam{136, 1, 1, 1, 2}, - Convolve1DTestParam{64, 1, 1, 1, 1}, - Convolve1DTestParam{128, 1, 1, 1, 1}, - Convolve1DTestParam{139, 1, 1, 128, 1}, - Convolve1DTestParam{1, 10, 10, 1, 10}, - Convolve1DTestParam{1, 10, 130, 1, 2}, - Convolve1DTestParam{1, 10, 130, 1, 1}, - Convolve1DTestParam{1, 64, 64, 1, 10}, - Convolve1DTestParam{1, 65, 65, 1, 1}, - Convolve1DTestParam{1, 128, 128, 1, 1}, - Convolve1DTestParam{128, 128, 128, 128, 1}, - Convolve1DTestParam{1, 128, 128, 1, 1}, - Convolve1DTestParam{2, 2, 2, 2, 1}, - Convolve1DTestParam{161, 1, 1, 10, 1}, - Convolve1DTestParam{900, 1, 1, 10, 1}, - Convolve1DTestParam{640, 3, 3, 128, 1}) - -); - -#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) -class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; - -XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) { - TestImpl(); -} - -INSTANTIATE_TEST_CASE_P( - Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf, - ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, - Convolve1DTestParam{160, 1, 1, 5, 1}, - Convolve1DTestParam{24, 1, 1, 20, 1}, - Convolve1DTestParam{30, 1, 1, 20, 1}, - Convolve1DTestParam{23, 1, 1, 20, 20}, - Convolve1DTestParam{25, 1, 1, 20, 1}, - Convolve1DTestParam{24, 1, 1, 10, 5}, - Convolve1DTestParam{160, 1, 1, 10, 1}, - Convolve1DTestParam{255, 1, 1, 3, 1}, - Convolve1DTestParam{130, 1, 1, 1, 3}, - Convolve1DTestParam{64, 1, 
1, 1, 1}, - Convolve1DTestParam{128, 1, 1, 1, 1}, - Convolve1DTestParam{139, 1, 1, 128, 1}, - Convolve1DTestParam{640, 3, 3, 128, 1}, - Convolve1DTestParam{900, 1, 1, 10, 1}, - Convolve1DTestParam{1, 10, 10, 1, 10}, - Convolve1DTestParam{1, 10, 130, 1, 1}, - Convolve1DTestParam{1, 10, 130, 1, 2}, - Convolve1DTestParam{1, 64, 64, 1, 10}, - Convolve1DTestParam{1, 65, 65, 1, 1}, - Convolve1DTestParam{1, 128, 128, 1, 1}, - Convolve1DTestParam{128, 128, 128, 128, 1}, - Convolve1DTestParam{1, 128, 128, 1, 1}, - Convolve1DTestParam{2, 2, 2, 2, 1}, - Convolve1DTestParam{161, 1, 1, 10, 1}) - -); -#endif XLA_TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/convolution_test_1d.cc b/tensorflow/compiler/xla/tests/convolution_test_1d.cc new file mode 100644 index 00000000000..2b2bf098145 --- /dev/null +++ b/tensorflow/compiler/xla/tests/convolution_test_1d.cc @@ -0,0 +1,376 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Tests of 1D convolution with trivial kernels and no special variations (like +// strides and padding). + +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class ConvolutionTest : public ClientLibraryTestBase { + protected: +#if XLA_TEST_BACKEND_GPU + // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial + // convolution. So relax the absolute error threshold. 
+ ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-3); +#else + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-3); +#endif +}; + +#ifdef XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16 +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif + +struct Convolve1DTestParam { + int64 input_feature; + int64 output_feature; + int64 batch; + int64 window_size; + int64 num_windows; +}; + +class Convolve1D1WindowTestBase + : public ConvolutionTest, + public ::testing::WithParamInterface { + protected: + template + void TestImpl() { + XlaBuilder builder(TestName()); + int64 input_feature = GetParam().input_feature; + int64 output_feature = GetParam().output_feature; + int64 batch = GetParam().batch; + int64 num_windows = GetParam().num_windows; + int64 window_size = GetParam().window_size; + std::vector input_dims = {batch, window_size + num_windows - 1, + input_feature}; + std::vector filter_dims = {window_size, input_feature, + output_feature}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + { + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 1D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + + ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1.0f)); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r3 = input_r1.Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(1.0f)); + + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r3 = filter_r1.Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector expect_elems(batch * output_feature * num_windows, + static_cast(window_size * input_feature)); + auto expected_r1 = LiteralUtil::CreateR1(expect_elems); + auto expected_r3 = expected_r1.Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(input_r3).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, expected_r3, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {}; + +XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl(); } + +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 2}, + Convolve1DTestParam{136, 1, 1, 1, 2}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, + 
Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 10, 130, 1, 1}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}) + +); + +#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) +class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; + +XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) { + TestImpl(); +} + +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, + Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 1}, + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}) + +); +#endif + +XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { + XlaBuilder builder(TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + Conv(input, filter, {1}, Padding::kValid); + } + + Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); + Array3D filter({{{10, 20}, {30, 40}}}); + + Array3D expected({{{510, 610, 710, 810}}}); + + auto input_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +template +class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + { + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. 
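Because the parameterized Convolve1D1Window cases feed all-ones inputs and filters, every output element is simply the number of multiply-accumulates, which keeps the expected literal cheap to build for any of the parameter tuples above. A sketch of that bookkeeping (not part of the test itself):

// For all-ones data, each output element of the valid 1D convolution equals
// window_size * input_feature, and there are batch * num_windows *
// output_feature such elements.
struct Convolve1DTestParamSketch {
  long long input_feature, output_feature, batch, window_size, num_windows;
};
float ExpectedElement(const Convolve1DTestParamSketch& p) {
  return static_cast<float>(p.window_size * p.input_feature);
}
long long NumOutputElements(const Convolve1DTestParamSketch& p) {
  return p.batch * p.num_windows * p.output_feature;
}
// E.g. for {640, 3, 3, 128, 1}: each element is 128 * 640 = 81920 and the
// output holds 3 * 1 * 3 = 9 elements.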
+ ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected({{{570.0f, 670.0f, 770.0f}}}); + + auto input_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; // namespace + +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } + +XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { + XlaBuilder builder(TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, + /*lhs_dilation=*/{2}, /*rhs_dilation=*/{1}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); + Array3D filter({{{10, 20}, {30, 40}}}); + + Array3D expected({{{190, 320, 230, 380, 270, 440, 310, 500}}}); + + auto input_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { + XlaBuilder builder(TestName()); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. 
+ ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, + /*lhs_dilation=*/{2}, /*rhs_dilation=*/{2}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); + Array3D filter({{{10, 20}, {30, 40}}}); + + Array3D expected({{{510, 0, 610, 0, 710, 0, 810}}}); + + auto input_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); +} + +template +class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { + public: + void RunTest() { + XlaBuilder builder(TestName()); + { + Shape input_shape = ShapeUtil::MakeShapeWithType({1, 2, 5}); + Shape filter_shape = ShapeUtil::MakeShapeWithType({1, 2, 2}); + auto input = Parameter(&builder, 0, input_shape, "input"); + auto filter = Parameter(&builder, 1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected( + {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); + + auto input_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(LiteralUtil::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 26cb25acbfe..60ba27b2050 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -1429,19 +1429,137 @@ ENTRY main { EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); } -XLA_TEST_F(DotOperationTextTest, IntegerDotCodegen) { +XLA_TEST_F(DotOperationTextTest, S32IotaDot) { absl::string_view hlo_string = R"( HloModule SmallIntegerDot ENTRY SmallIntegerDot { - arg0 = s32[1,2,2] parameter(0) - arg1 = s32[1,2,1] parameter(1) - ROOT dot = s32[1,2,1] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + arg0 = s32[5,55,8] iota(), iota_dimension=1 + arg1 = s32[5,8,200] iota(), iota_dimension=2 + ROOT dot = s32[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} } )"; - EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{4e-3, 4e-3})); + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, S32IotaSquaredDot) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[16,2] iota(), iota_dimension=0 + a = s32[16,2] multiply(arg0, arg0) + r = s32[16,2] multiply(a, a) + arg1 = s32[2,98] iota(), 
iota_dimension=1 + b = s32[2,98] multiply(arg1, arg1) + s = s32[2,98] multiply(b, b) + ROOT dot = s32[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(U16IotaDot)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = u16[5,55,8] parameter(0) + arg1 = u16[5,8,200] parameter(1) + dot = u16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT c = s32[5,55,200] convert(dot) +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(U16IotaSquaredDot)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = u16[16,2] iota(), iota_dimension=0 + a = u16[16,2] multiply(arg0, arg0) + r = u16[16,2] multiply(a, a) + arg1 = u16[2,98] iota(), iota_dimension=1 + b = u16[2,98] multiply(arg1, arg1) + s = u16[2,98] multiply(b, b) + ROOT dot = u16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(S16IotaDot)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s16[5,55,8] iota(), iota_dimension=1 + arg1 = s16[5,8,200] iota(), iota_dimension=2 + ROOT dot = s16[5,55,200] dot(arg0, arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(S16IotaSquaredDot)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s16[16,2] iota(), iota_dimension=0 + a = s16[16,2] multiply(arg0, arg0) + r = s16[16,2] multiply(a, a) + arg1 = s16[2,98] iota(), iota_dimension=1 + b = s16[2,98] multiply(arg1, arg1) + s = s16[2,98] multiply(b, b) + ROOT dot = s16[16,98] dot(r, s), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, DISABLED_ON_CPU(S8Dot)) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s8[20,2] parameter(0) + arg1 = s8[2,20] parameter(1) + ROOT dot = s8[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); +} + +XLA_TEST_F(DotOperationTextTest, S32Dot) { + absl::string_view hlo_string = + R"( +HloModule SmallIntegerDot + +ENTRY SmallIntegerDot { + arg0 = s32[20,55] parameter(0) + arg1 = s32[55,20] parameter(1) + ROOT dot = s32[20,20] dot(arg0, arg1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0, 0})); } XLA_TEST_F(DotOperationTextTest, GpuTransposeOutput) { diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc new file mode 100644 index 00000000000..ba4092def16 --- /dev/null +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/strings/match.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/client/lib/prng.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// An enumerator for the client types that we want to iterate over in +// the various tests. +enum class ClientType { kLocal, kCompileOnly }; +ClientType client_types[] = {ClientType::kLocal, ClientType::kCompileOnly}; + +class DynamismInferenceTest : public ::testing::Test { + public: + explicit DynamismInferenceTest(se::Platform* platform = nullptr) + : platform_(platform) {} + + string TestName() const { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + Client* ClientOrDie(se::Platform* platform, ClientType client_type) { + if (client_type == ClientType::kLocal) { + StatusOr result = + ClientLibrary::GetOrCreateLocalClient(platform); + TF_CHECK_OK(result.status()) + << "could not create LocalClient for testing"; + return result.ValueOrDie(); + } else if (client_type == ClientType::kCompileOnly) { + StatusOr result = + ClientLibrary::GetOrCreateCompileOnlyClient(platform); + TF_CHECK_OK(result.status()) + << "could not create CompileOnlyClient for testing"; + return result.ValueOrDie(); + } + LOG(FATAL) << "invalid client_type value"; + } + + StatusOr ComputeDynamismLiteral(Client* client, XlaOp operand, + XlaBuilder* builder, + Layout* output_layout = nullptr) { + TF_ASSIGN_OR_RETURN(auto subgraph, + builder->BuildDynamicInferenceGraph(operand)); + TF_ASSIGN_OR_RETURN(auto computed, + client->ComputeConstant(subgraph, output_layout)); + return std::move(computed); + } + + StatusOr ComputeDynamismScalar(Client* client, XlaOp operand, + XlaBuilder* builder, + ShapeIndex index = {}) { + TF_ASSIGN_OR_RETURN(auto literal, ComputeDynamismLiteral(client, operand, + builder, nullptr)); + return literal.Get({}, index); + } + + se::Platform* platform_; +}; + +TEST_F(DynamismInferenceTest, ScalarInt32Literal) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = ConstantR0(&b, 42); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A constant is 
not dynamic. + EXPECT_EQ(value.ValueOrDie(), false); + } +} + +TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto tuple = Tuple(&b, {c, p}); + auto gte0 = GetTupleElement(tuple, 0); + auto gte1 = GetTupleElement(tuple, 1); + auto tuple_2 = Tuple(&b, {gte0, gte1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto concat = ConcatScalars(&b, {c, p}); + auto slice0 = SliceInDim(concat, 0, 1, 1, 0); + auto reshape0 = Reshape(slice0, {}); + auto slice1 = SliceInDim(concat, 1, 2, 1, 0); + auto reshape1 = Reshape(slice1, {}); + auto tuple_2 = Tuple(&b, {reshape0, reshape1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, ParameterIsDynamic) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto value = ComputeDynamismScalar(client, computation, &b); + ASSERT_TRUE(value.ok()) << value.status(); + // A parameter is considered dynamic. 
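The dynamism-inference cases here and below pin down a simple propagation rule: constants are static, parameters are dynamic, and most ops derive their dynamism from their operands (element-wise ops OR it together, tuples, slices, and reshapes just forward it). A toy version of that rule, written independently of XlaBuilder::BuildDynamicInferenceGraph, is:

#include <vector>

// Sketch of dynamism propagation; the real pass builds a parallel HLO graph
// of predicates rather than walking a toy node structure like this one.
enum class Kind { kConstant, kParameter, kElementwise };
struct Node {
  Kind kind;
  std::vector<const Node*> operands;
};

bool IsDynamic(const Node& n) {
  switch (n.kind) {
    case Kind::kConstant:
      return false;  // Literal values are known at compile time.
    case Kind::kParameter:
      return true;   // Runtime-fed values are dynamic.
    case Kind::kElementwise: {
      bool dynamic = false;  // OR the dynamism of the operands.
      for (const Node* op : n.operands) dynamic = dynamic || IsDynamic(*op);
      return dynamic;
    }
  }
  return false;
}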
+ EXPECT_EQ(value.ValueOrDie(), true); + } +} + +TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + auto neg0 = Neg(c); + auto neg1 = Neg(p); + auto tuple_2 = Tuple(&b, {neg0, neg1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + + // Static value + static value = static + auto add1 = Add(c, c); + // Dynamic value + dynamic value = dynamic + auto add2 = Add(p, c); + auto tuple_2 = Tuple(&b, {add1, add2}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + true); + } +} + +TEST_F(DynamismInferenceTest, GetDimensionSize) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + // param = Param([<=2, 3]) + // get_dimension_size(param, 0) is dynamic + // get_dimension_size(param, 1) is static + auto p = + Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0"); + + auto gds0 = GetDimensionSize(p, 0); + auto gds1 = GetDimensionSize(p, 1); + auto tuple_2 = Tuple(&b, {gds0, gds1}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {0}).ValueOrDie(), + true); + EXPECT_EQ(ComputeDynamismScalar(client, tuple_2, &b, {1}).ValueOrDie(), + false); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index b83fed07e34..201c0da87f1 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -103,7 +103,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values( // The largest negative number smaller than zero in bf16 that's not // denormalized. - std::make_pair(static_cast(-bfloat16::min_positive_normal()), + std::make_pair(static_cast( + -std::numeric_limits::min()), 0.0f), // Test odd and even values. 
std::make_pair(32.75f, 33.00f), std::make_pair(32.50f, 32.75f), diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index b113b498e22..fc1ca7d3105 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -308,17 +308,18 @@ cc_library( ":prepare_reference_module", "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:error_spec", + "//tensorflow/compiler/xla:literal_comparison", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:status", - "//tensorflow/core/platform:test", "//tensorflow/stream_executor:platform", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -339,6 +340,7 @@ tf_cc_binary( "//tensorflow/core:framework_internal", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:platform_port", + "//tensorflow/core/platform:path", "//tensorflow/core/platform:status", "//tensorflow/core/platform:test", ] + if_cuda_or_rocm([ diff --git a/tensorflow/compiler/xla/tools/hlo_module_loader.cc b/tensorflow/compiler/xla/tools/hlo_module_loader.cc index b3aaba7fa25..8b70b0d35a7 100644 --- a/tensorflow/compiler/xla/tools/hlo_module_loader.cc +++ b/tensorflow/compiler/xla/tools/hlo_module_loader.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include -#include "google/protobuf/text_format.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" @@ -32,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -87,9 +87,10 @@ StatusOr> LoadModuleFromData( return InvalidArgument("Failed to parse input as HLO protobuf binary"); } } else if (format == "pbtxt") { - if (!google::protobuf::TextFormat::ParseFromString(data, &proto) && - !google::protobuf::TextFormat::ParseFromString(data, proto.mutable_hlo()) && - !google::protobuf::TextFormat::ParseFromString( + if (!tensorflow::protobuf::TextFormat::ParseFromString(data, &proto) && + !tensorflow::protobuf::TextFormat::ParseFromString( + data, proto.mutable_hlo()) && + !tensorflow::protobuf::TextFormat::ParseFromString( data, proto.mutable_hlo()->mutable_hlo_module())) { return InvalidArgument("Failed to parse input as HLO protobuf text"); } diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.cc b/tensorflow/compiler/xla/tools/run_hlo_module.cc index 39b545af393..be9b23efb12 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module.cc @@ -27,24 +27,66 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/testing.h" #include "tensorflow/compiler/xla/debug_options_flags.h" #include "tensorflow/compiler/xla/error_spec.h" +#include "tensorflow/compiler/xla/literal_comparison.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/tools/hlo_module_loader.h" #include "tensorflow/compiler/xla/tools/prepare_reference_module.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" - -namespace se = ::stream_executor; namespace xla { namespace { +// Writes the given literal to a file in the test temporary directory. +void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) { + // Bazel likes for tests to write "debugging outputs" like these to + // TEST_UNDECLARED_OUTPUTS_DIR. This plays well with tools that inspect test + // results, especially when they're run on remote machines. + auto* env = tensorflow::Env::Default(); + string binary_filename; + string text_filename; + string outdir; + if (tensorflow::io::GetTestUndeclaredOutputsDir(&outdir)) { + string filename = tensorflow::io::JoinPath( + outdir, absl::StrFormat("tempfile-%d-%s", env->NowMicros(), name)); + binary_filename = absl::StrCat(filename, ".pb"); + text_filename = absl::StrCat(filename, ".txt"); + } else { + binary_filename = + tensorflow::io::GetTempFilename(absl::StrCat(name, ".pb")); + text_filename = tensorflow::io::GetTempFilename(absl::StrCat(name, ".txt")); + } + + TF_CHECK_OK( + tensorflow::WriteBinaryProto(env, binary_filename, literal.ToProto())); + TF_CHECK_OK( + tensorflow::WriteStringToFile(env, text_filename, literal.ToString())); + LOG(ERROR) << "wrote Literal to " << name << " binary: " << binary_filename + << " text: " << text_filename; +} + +// Callback helper that dumps literals to temporary files in the event of a +// miscomparison. 
+void OnMiscompare(const LiteralSlice& expected, const LiteralSlice& actual, + const LiteralSlice& mismatches, + const ShapeIndex& /*shape_index*/) { + LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) << " " + << literal_comparison::ToStringTruncated(expected); + LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) << " " + << literal_comparison::ToStringTruncated(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(mismatches, "mismatches"); +} + Literal ExecuteOnPlatform(std::unique_ptr module, absl::Span args, se::Platform* platform, bool run_hlo_passes) { @@ -69,7 +111,7 @@ Literal ExecuteOnPlatform(std::unique_ptr module, } } // namespace -::testing::AssertionResult RunAndCompare( +Status RunAndCompare( const std::string& hlo_filename, const std::string& test_platform_name, const std::string& reference_platform_name, std::minstd_rand0* engine, const RunHloModuleOptions& options, @@ -122,7 +164,7 @@ Literal ExecuteOnPlatform(std::unique_ptr module, if (reference_module == nullptr) { std::cerr << "Skipping reference platform\n"; - return ::testing::AssertionSuccess(); + return Status::OK(); } Literal reference_result = @@ -136,10 +178,10 @@ Literal ExecuteOnPlatform(std::unique_ptr module, } ErrorSpec error_spec(static_cast(options.abs_error_bound), static_cast(options.rel_error_bound)); - return LiteralTestUtil::Near(/*expected=*/reference_result, - /*actual=*/test_result, - /*error_spec=*/error_spec, - /*detailed_message=*/true); + return literal_comparison::Near(/*expected=*/reference_result, + /*actual=*/test_result, + /*error=*/error_spec, + /*detailed_message=*/true, &OnMiscompare); } } // namespace xla diff --git a/tensorflow/compiler/xla/tools/run_hlo_module.h b/tensorflow/compiler/xla/tools/run_hlo_module.h index 932cc22f4dd..57f81cc7c94 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module.h +++ b/tensorflow/compiler/xla/tools/run_hlo_module.h @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/stream_executor/platform.h" namespace xla { @@ -63,7 +62,7 @@ struct RunHloModuleOptions { // the results. 'reference_module_modifier_hook' can be used to transform the // HloModule before it is run on the reference platform. This may be necessary // to match the numerics of the test platform. -::testing::AssertionResult RunAndCompare( +Status RunAndCompare( const std::string& hlo_filename, const std::string& test_platform_name, const std::string& reference_platform_name, std::minstd_rand0* engine, const RunHloModuleOptions& options, diff --git a/tensorflow/compiler/xla/tools/run_hlo_module_main.cc b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc index 39d7826e162..9d153491862 100644 --- a/tensorflow/compiler/xla/tools/run_hlo_module_main.cc +++ b/tensorflow/compiler/xla/tools/run_hlo_module_main.cc @@ -156,7 +156,7 @@ int main(int argc, char** argv) { if (iteration_count != 1) { std::cerr << "\n=== Iteration " << i << "\n"; } - ::testing::AssertionResult matched = + xla::Status matched = xla::RunAndCompare(hlo_filename, test_platform_name, reference_platform_name, &engine, opts); @@ -164,13 +164,13 @@ int main(int argc, char** argv) { // used. Without a reference, the test just verifies that nothing blew up // when running the module. 
if (!reference_platform_name.empty()) { - if (matched) { + if (matched.ok()) { // Success. std::cerr << "\n** Results on " << test_platform_name << " and " << reference_platform_name << " are close enough. **\n"; } else { failure_count++; - std::cerr << matched.message() << "\n"; + std::cerr << matched << "\n"; } } } diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 42b6ea6bd53..1cf30b10373 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -270,8 +270,8 @@ message DebugOptions { // Paths to files with ptx code. repeated string xla_gpu_ptx_file = 127; - // Blacklist for cuDNN convolutions. - string xla_gpu_algorithm_blacklist_path = 128; + // Denylist for cuDNN convolutions. + string xla_gpu_algorithm_denylist_path = 128; // Guarantee run-to-run determinism from reductions on XLA:GPU. bool xla_gpu_deterministic_reductions = 130; @@ -349,6 +349,10 @@ message ExecutionOptions { // Indicates whether to use SPMD (true) or MPMD (false) partitioning when // num_partitions > 1 and XLA is requested to partition the input program. bool use_spmd_partitioning = 11; + + // If set, deduplicate hlo into function calls to reduce binary size. Only + // works on TPU. + bool deduplicate_hlo = 12; } message GetDeviceHandlesRequest { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index e8b6105d3fe..d334f879c3e 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -627,6 +627,11 @@ message OpSharding { // applied, this is inferred from the instruction this sharding gets attached // to. repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension. (Experimental) + bool replicate_on_last_tile_dim = 6; } // Describes the replica groups in a cross replica op (e.g., all-reduce and diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index c2f9a1c62c9..c4094795a96 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -650,7 +650,7 @@ Status XRTTupleAllocation::AliasBufferFrom(const XRTTupleAllocation& source, xla::StatusOr XRTTupleAllocation::ToExecutionInput( const std::function(const xla::ShapeIndex&)>& alias_checker) { - xla::ExecutionInput result(on_device_shape()); + xla::ExecutionInput result(on_device_shape(), on_host_shape()); for (const auto& index_buffer : buffers_) { if (index_buffer.second == nullptr || (index_buffer.second->allocation().is_null() && diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 18341a81df4..12e143e7933 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -458,7 +458,7 @@ tf_cuda_library( "//tensorflow/core/framework:common_shape_fns.h", "//tensorflow/core/framework:control_flow.h", # TODO(josh11b): Make internal? 
"//tensorflow/core/framework:dataset.h", - "//tensorflow/core/framework:dataset_stateful_op_whitelist.h", + "//tensorflow/core/framework:dataset_stateful_op_allowlist.h", "//tensorflow/core/framework:device_base.h", "//tensorflow/core/framework:function.h", "//tensorflow/core/framework:function_handle_cache.h", @@ -629,8 +629,8 @@ tf_gen_op_libs( "io_ops", "linalg_ops", "list_ops", + "map_ops", "lookup_ops", - "logging_ops", "manip_ops", "math_ops", "mkl_nn_ops", @@ -664,6 +664,19 @@ tf_gen_op_libs( ], ) +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "logging_ops", + ], + deps = [ + ":lib", + ":protos_all_cc", + # TODO(b/162630222): remove this dependency. + "//tensorflow/c/kernels:summary_op_lib", + ], +) + tf_gen_op_libs( op_lib_names = [ "string_ops", @@ -798,36 +811,7 @@ tf_gen_op_libs( "ragged_conversion_ops", "ragged_math_ops", ], - deps = [":ragged_to_dense_util"], -) - -cc_library( - name = "ragged_to_dense_util", - srcs = [ - "ops/ragged_to_dense_util.cc", - ], - hdrs = [ - "ops/ragged_to_dense_util.h", - ], - deps = [ - ":framework", - ":protos_all_cc", - ], -) - -tf_cc_test( - name = "ragged_to_dense_util_test", - srcs = [ - "ops/ragged_to_dense_util_test.cc", - ], - deps = [ - ":framework", - ":protos_all_cc", - ":ragged_to_dense_util", - ":test", - ":testlib", - "@com_google_googletest//:gtest_main", - ], + deps = ["//tensorflow/core/util:ragged_to_dense_util"], ) cc_library( @@ -860,6 +844,7 @@ cc_library( ":io_ops_op_lib", ":linalg_ops_op_lib", ":list_ops_op_lib", + ":map_ops_op_lib", ":logging_ops_op_lib", ":lookup_ops_op_lib", ":manip_ops_op_lib", @@ -892,6 +877,7 @@ cc_library( ":user_ops_op_lib", ":word2vec_ops", "//tensorflow/c/kernels:bitcast_op_lib", + "//tensorflow/c/kernels:summary_op_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_passthrough_op", ] + if_chromiumos( [], @@ -909,6 +895,7 @@ cc_library( ":tpu_outfeed_ops_op_lib", ":tpu_ordinal_selector_ops_op_lib", ":tpu_replication_ops_op_lib", + "//tensorflow/core/tpu/ops", ], ) + if_mkl([ ":mkl_array_ops_op_lib", @@ -998,6 +985,7 @@ cc_library( name = "all_kernels_impl", visibility = [":__subpackages__"], deps = [ + "//tensorflow/c/kernels:summary_op", "//tensorflow/c/kernels:bitcast_op", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", @@ -1022,9 +1010,7 @@ cc_library( "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", - "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", - "//tensorflow/core/kernels:linalg", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", @@ -1058,32 +1044,34 @@ cc_library( "//tensorflow/core/kernels:summary_kernels", "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:word2vec_kernels", + "//tensorflow/core/kernels/linalg:linalg", + "//tensorflow/core/kernels/image:image", "//tensorflow/core/kernels/sparse:kernels", ] + if_not_windows([ "//tensorflow/core/kernels/neon:neon_depthwise_conv_op", ]) + if_mkl([ - "//tensorflow/core/kernels:mkl_aggregate_ops", - "//tensorflow/core/kernels:mkl_concat_op", - "//tensorflow/core/kernels:mkl_dequantize_op", - "//tensorflow/core/kernels:mkl_conv_op", - "//tensorflow/core/kernels:mkl_cwise_ops_common", - "//tensorflow/core/kernels:mkl_fused_batch_norm_op", - "//tensorflow/core/kernels:mkl_identity_op", - "//tensorflow/core/kernels:mkl_input_conversion_op", - "//tensorflow/core/kernels:mkl_lrn_op", - "//tensorflow/core/kernels:mkl_pooling_ops", - 
"//tensorflow/core/kernels:mkl_qmatmul_op", - "//tensorflow/core/kernels:mkl_requantize_ops", - "//tensorflow/core/kernels:mkl_quantize_op", - "//tensorflow/core/kernels:mkl_relu_op", - "//tensorflow/core/kernels:mkl_reshape_op", - "//tensorflow/core/kernels:mkl_slice_op", - "//tensorflow/core/kernels:mkl_softmax_op", - "//tensorflow/core/kernels:mkl_transpose_op", - "//tensorflow/core/kernels:mkl_batch_matmul_op", - "//tensorflow/core/kernels:mkl_matmul_op", - "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:mkl_tmp_bf16_ops", + "//tensorflow/core/kernels/mkl:mkl_aggregate_ops", + "//tensorflow/core/kernels/mkl:mkl_concat_op", + "//tensorflow/core/kernels/mkl:mkl_dequantize_op", + "//tensorflow/core/kernels/mkl:mkl_conv_op", + "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common", + "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels/mkl:mkl_identity_op", + "//tensorflow/core/kernels/mkl:mkl_input_conversion_op", + "//tensorflow/core/kernels/mkl:mkl_lrn_op", + "//tensorflow/core/kernels/mkl:mkl_pooling_ops", + "//tensorflow/core/kernels/mkl:mkl_qmatmul_op", + "//tensorflow/core/kernels/mkl:mkl_requantize_ops", + "//tensorflow/core/kernels/mkl:mkl_quantize_op", + "//tensorflow/core/kernels/mkl:mkl_relu_op", + "//tensorflow/core/kernels/mkl:mkl_reshape_op", + "//tensorflow/core/kernels/mkl:mkl_slice_op", + "//tensorflow/core/kernels/mkl:mkl_softmax_op", + "//tensorflow/core/kernels/mkl:mkl_transpose_op", + "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_tfconv_op", + "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops", ]) + if_cuda_or_rocm([ "//tensorflow/core/kernels:cudnn_rnn_kernels", ]) + if_cuda([ @@ -1121,6 +1109,8 @@ cc_library( # these also dynamically loading. 
"//tensorflow/core/kernels:dataset_ops", # Depends on grappler "//tensorflow/core/kernels:list_kernels", # Depends on variant_op_registry.h + "//tensorflow/core/kernels:map_kernels", + "//tensorflow/core/kernels:tensor_map", ], ) @@ -1927,6 +1917,7 @@ cc_library( "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:platform_strings", "//tensorflow/core/platform:prefetch", + "//tensorflow/core/platform:profile_utils_cpu_utils", "//tensorflow/core/platform:protobuf_internal", "//tensorflow/core/platform:regexp", "//tensorflow/core/platform:resource", @@ -1986,6 +1977,7 @@ cc_library( ":lib", ":lib_internal", "//tensorflow/core/platform:gif", + "@com_google_absl//absl/strings", ], ) @@ -2093,13 +2085,9 @@ cc_library( copts = tf_copts(), linkopts = ["-ldl"], deps = [ - "//tensorflow/core/lib/strings:numbers", - "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:gif", "//tensorflow/core/platform:logging", - "//tensorflow/core/platform:numbers", - "//tensorflow/core/platform:strcat", "//tensorflow/core/platform:stringpiece", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", @@ -2706,27 +2694,27 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:ops_util", "//third_party/eigen3", ] + if_mkl([ - "//tensorflow/core/kernels:mkl_aggregate_ops", - "//tensorflow/core/kernels:mkl_batch_matmul_op", - "//tensorflow/core/kernels:mkl_concat_op", - "//tensorflow/core/kernels:mkl_conv_op", - "//tensorflow/core/kernels:mkl_cwise_ops_common", - "//tensorflow/core/kernels:mkl_dequantize_op", - "//tensorflow/core/kernels:mkl_fused_batch_norm_op", - "//tensorflow/core/kernels:mkl_identity_op", - "//tensorflow/core/kernels:mkl_input_conversion_op", - "//tensorflow/core/kernels:mkl_lrn_op", - "//tensorflow/core/kernels:mkl_matmul_op", - "//tensorflow/core/kernels:mkl_pooling_ops", - "//tensorflow/core/kernels:mkl_qmatmul_op", - "//tensorflow/core/kernels:mkl_quantize_op", - "//tensorflow/core/kernels:mkl_relu_op", - "//tensorflow/core/kernels:mkl_reshape_op", - "//tensorflow/core/kernels:mkl_slice_op", - "//tensorflow/core/kernels:mkl_softmax_op", - "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:mkl_transpose_op", - "//tensorflow/core/kernels:mkl_tmp_bf16_ops", + "//tensorflow/core/kernels/mkl:mkl_aggregate_ops", + "//tensorflow/core/kernels/mkl:mkl_batch_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_concat_op", + "//tensorflow/core/kernels/mkl:mkl_conv_op", + "//tensorflow/core/kernels/mkl:mkl_cwise_ops_common", + "//tensorflow/core/kernels/mkl:mkl_dequantize_op", + "//tensorflow/core/kernels/mkl:mkl_fused_batch_norm_op", + "//tensorflow/core/kernels/mkl:mkl_identity_op", + "//tensorflow/core/kernels/mkl:mkl_input_conversion_op", + "//tensorflow/core/kernels/mkl:mkl_lrn_op", + "//tensorflow/core/kernels/mkl:mkl_matmul_op", + "//tensorflow/core/kernels/mkl:mkl_pooling_ops", + "//tensorflow/core/kernels/mkl:mkl_qmatmul_op", + "//tensorflow/core/kernels/mkl:mkl_quantize_op", + "//tensorflow/core/kernels/mkl:mkl_relu_op", + "//tensorflow/core/kernels/mkl:mkl_reshape_op", + "//tensorflow/core/kernels/mkl:mkl_slice_op", + "//tensorflow/core/kernels/mkl:mkl_softmax_op", + "//tensorflow/core/kernels/mkl:mkl_tfconv_op", + "//tensorflow/core/kernels/mkl:mkl_transpose_op", + "//tensorflow/core/kernels/mkl:mkl_tmp_bf16_ops", ]), ) diff --git a/tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt index 8171566a212..fcaa93acac1 
--- a/tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_AvgPool3D.pbtxt
@@ -43,4 +43,8 @@ Alternatively, the format could be "NCDHW", the data storage order is:
 END
   }
   summary: "Performs 3D average pooling on the input."
+  description: <